Log list of tool calls in ActionStep (#172)
This commit is contained in:
		
							parent
							
								
									5a62304c91
								
							
						
					
					
						commit
						289c06df0f
					
				|  | @ -82,7 +82,7 @@ class AgentStep: | ||||||
| @dataclass | @dataclass | ||||||
| class ActionStep(AgentStep): | class ActionStep(AgentStep): | ||||||
|     agent_memory: List[Dict[str, str]] | None = None |     agent_memory: List[Dict[str, str]] | None = None | ||||||
|     tool_call: ToolCall | None = None |     tool_calls: List[ToolCall] | None = None | ||||||
|     start_time: float | None = None |     start_time: float | None = None | ||||||
|     end_time: float | None = None |     end_time: float | None = None | ||||||
|     step: int | None = None |     step: int | None = None | ||||||
|  | @ -302,25 +302,26 @@ class MultiStepAgent: | ||||||
|                     } |                     } | ||||||
|                     memory.append(thought_message) |                     memory.append(thought_message) | ||||||
| 
 | 
 | ||||||
|                 if step_log.tool_call is not None: |                 if step_log.tool_calls is not None: | ||||||
|                     tool_call_message = { |                     tool_call_message = { | ||||||
|                         "role": MessageRole.ASSISTANT, |                         "role": MessageRole.ASSISTANT, | ||||||
|                         "content": str( |                         "content": str( | ||||||
|                             [ |                             [ | ||||||
|                                 { |                                 { | ||||||
|                                     "id": step_log.tool_call.id, |                                     "id": tool_call.id, | ||||||
|                                     "type": "function", |                                     "type": "function", | ||||||
|                                     "function": { |                                     "function": { | ||||||
|                                         "name": step_log.tool_call.name, |                                         "name": tool_call.name, | ||||||
|                                         "arguments": step_log.tool_call.arguments, |                                         "arguments": tool_call.arguments, | ||||||
|                                     }, |                                     }, | ||||||
|                                 } |                                 } | ||||||
|  |                                 for tool_call in step_log.tool_calls | ||||||
|                             ] |                             ] | ||||||
|                         ), |                         ), | ||||||
|                     } |                     } | ||||||
|                     memory.append(tool_call_message) |                     memory.append(tool_call_message) | ||||||
| 
 | 
 | ||||||
|                 if step_log.tool_call is None and step_log.error is not None: |                 if step_log.tool_calls is None and step_log.error is not None: | ||||||
|                     message_content = ( |                     message_content = ( | ||||||
|                         "Error:\n" |                         "Error:\n" | ||||||
|                         + str(step_log.error) |                         + str(step_log.error) | ||||||
|  | @ -330,7 +331,7 @@ class MultiStepAgent: | ||||||
|                         "role": MessageRole.ASSISTANT, |                         "role": MessageRole.ASSISTANT, | ||||||
|                         "content": message_content, |                         "content": message_content, | ||||||
|                     } |                     } | ||||||
|                 if step_log.tool_call is not None and ( |                 if step_log.tool_calls is not None and ( | ||||||
|                     step_log.error is not None or step_log.observations is not None |                     step_log.error is not None or step_log.observations is not None | ||||||
|                 ): |                 ): | ||||||
|                     if step_log.error is not None: |                     if step_log.error is not None: | ||||||
|  | @ -343,7 +344,7 @@ class MultiStepAgent: | ||||||
|                         message_content = f"Observation:\n{step_log.observations}" |                         message_content = f"Observation:\n{step_log.observations}" | ||||||
|                     tool_response_message = { |                     tool_response_message = { | ||||||
|                         "role": MessageRole.TOOL_RESPONSE, |                         "role": MessageRole.TOOL_RESPONSE, | ||||||
|                         "content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n" |                         "content": f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n" | ||||||
|                         + message_content, |                         + message_content, | ||||||
|                     } |                     } | ||||||
|                     memory.append(tool_response_message) |                     memory.append(tool_response_message) | ||||||
|  | @ -814,9 +815,9 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|                 f"Error in generating tool call with model:\n{e}" |                 f"Error in generating tool call with model:\n{e}" | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         log_entry.tool_call = ToolCall( |         log_entry.tool_calls = [ | ||||||
|             name=tool_name, arguments=tool_arguments, id=tool_call_id |             ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id) | ||||||
|         ) |         ] | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         self.logger.log( |         self.logger.log( | ||||||
|  | @ -977,11 +978,13 @@ class CodeAgent(MultiStepAgent): | ||||||
|             ) |             ) | ||||||
|             raise AgentParsingError(error_msg) |             raise AgentParsingError(error_msg) | ||||||
| 
 | 
 | ||||||
|         log_entry.tool_call = ToolCall( |         log_entry.tool_calls = [ | ||||||
|             name="python_interpreter", |             ToolCall( | ||||||
|             arguments=code_action, |                 name="python_interpreter", | ||||||
|             id=f"call_{len(self.logs)}", |                 arguments=code_action, | ||||||
|         ) |                 id=f"call_{len(self.logs)}", | ||||||
|  |             ) | ||||||
|  |         ] | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         self.logger.log( |         self.logger.log( | ||||||
|  |  | ||||||
|  | @ -43,9 +43,7 @@ class Monitor: | ||||||
|     def update_metrics(self, step_log): |     def update_metrics(self, step_log): | ||||||
|         step_duration = step_log.duration |         step_duration = step_log.duration | ||||||
|         self.step_durations.append(step_duration) |         self.step_durations.append(step_duration) | ||||||
|         console_outputs = ( |         console_outputs = f"[Step {len(self.step_durations) - 1}: Duration {step_duration:.2f} seconds" | ||||||
|             f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds" |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         if getattr(self.tracked_model, "last_input_token_count", None) is not None: |         if getattr(self.tracked_model, "last_input_token_count", None) is not None: | ||||||
|             self.total_input_token_count += self.tracked_model.last_input_token_count |             self.total_input_token_count += self.tracked_model.last_input_token_count | ||||||
|  |  | ||||||
|  | @ -179,12 +179,12 @@ class Tool: | ||||||
|                     f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." |                     f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." | ||||||
|                 ) |                 ) | ||||||
|         for input_name, input_content in self.inputs.items(): |         for input_name, input_content in self.inputs.items(): | ||||||
|             assert isinstance( |             assert isinstance(input_content, dict), ( | ||||||
|                 input_content, dict |                 f"Input '{input_name}' should be a dictionary." | ||||||
|             ), f"Input '{input_name}' should be a dictionary." |             ) | ||||||
|             assert ( |             assert "type" in input_content and "description" in input_content, ( | ||||||
|                 "type" in input_content and "description" in input_content |                 f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." | ||||||
|             ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." |             ) | ||||||
|             if input_content["type"] not in AUTHORIZED_TYPES: |             if input_content["type"] not in AUTHORIZED_TYPES: | ||||||
|                 raise Exception( |                 raise Exception( | ||||||
|                     f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}." |                     f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}." | ||||||
|  | @ -207,13 +207,13 @@ class Tool: | ||||||
|             json_schema = _convert_type_hints_to_json_schema(self.forward) |             json_schema = _convert_type_hints_to_json_schema(self.forward) | ||||||
|             for key, value in self.inputs.items(): |             for key, value in self.inputs.items(): | ||||||
|                 if "nullable" in value: |                 if "nullable" in value: | ||||||
|                     assert ( |                     assert key in json_schema and "nullable" in json_schema[key], ( | ||||||
|                         key in json_schema and "nullable" in json_schema[key] |                         f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||||
|                     ), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." |                     ) | ||||||
|                 if key in json_schema and "nullable" in json_schema[key]: |                 if key in json_schema and "nullable" in json_schema[key]: | ||||||
|                     assert ( |                     assert "nullable" in value, ( | ||||||
|                         "nullable" in value |                         f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." | ||||||
|                     ), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." |                     ) | ||||||
| 
 | 
 | ||||||
|     def forward(self, *args, **kwargs): |     def forward(self, *args, **kwargs): | ||||||
|         return NotImplementedError("Write this method in your subclass of `Tool`.") |         return NotImplementedError("Write this method in your subclass of `Tool`.") | ||||||
|  | @ -272,7 +272,7 @@ class Tool: | ||||||
|             class {class_name}(Tool): |             class {class_name}(Tool): | ||||||
|                 name = "{self.name}" |                 name = "{self.name}" | ||||||
|                 description = "{self.description}" |                 description = "{self.description}" | ||||||
|                 inputs = {json.dumps(self.inputs, separators=(',', ':'))} |                 inputs = {json.dumps(self.inputs, separators=(",", ":"))} | ||||||
|                 output_type = "{self.output_type}" |                 output_type = "{self.output_type}" | ||||||
|             """).strip() |             """).strip() | ||||||
|             import re |             import re | ||||||
|  | @ -439,7 +439,9 @@ class Tool: | ||||||
|                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the |                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the | ||||||
|                 others will be passed along to its init. |                 others will be passed along to its init. | ||||||
|         """ |         """ | ||||||
|         assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." |         assert trust_remote_code, ( | ||||||
|  |             "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         hub_kwargs_names = [ |         hub_kwargs_names = [ | ||||||
|             "cache_dir", |             "cache_dir", | ||||||
|  | @ -620,13 +622,10 @@ class Tool: | ||||||
|                     arg.save(temp_file.name) |                     arg.save(temp_file.name) | ||||||
|                     arg = temp_file.name |                     arg = temp_file.name | ||||||
|                 if ( |                 if ( | ||||||
|                     isinstance(arg, str) |                     (isinstance(arg, str) and os.path.isfile(arg)) | ||||||
|                     and os.path.isfile(arg) |                     or (isinstance(arg, Path) and arg.exists() and arg.is_file()) | ||||||
|                 ) or ( |                     or is_http_url_like(arg) | ||||||
|                     isinstance(arg, Path) |                 ): | ||||||
|                     and arg.exists() |  | ||||||
|                     and arg.is_file() |  | ||||||
|                 ) or is_http_url_like(arg): |  | ||||||
|                     arg = handle_file(arg) |                     arg = handle_file(arg) | ||||||
|                 return arg |                 return arg | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -99,7 +99,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: | ||||||
|         raise ValueError( |         raise ValueError( | ||||||
|             f"The JSON blob you used is invalid due to the following error: {e}.\n" |             f"The JSON blob you used is invalid due to the following error: {e}.\n" | ||||||
|             f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n" |             f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n" | ||||||
|             f"'{json_blob[place-4:place+5]}'." |             f"'{json_blob[place - 4 : place + 5]}'." | ||||||
|         ) |         ) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         raise ValueError(f"Error in parsing the JSON blob: {e}") |         raise ValueError(f"Error in parsing the JSON blob: {e}") | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue