Log list of tool calls in ActionStep (#172)
This commit is contained in:
parent
5a62304c91
commit
289c06df0f
|
@ -82,7 +82,7 @@ class AgentStep:
|
|||
@dataclass
|
||||
class ActionStep(AgentStep):
|
||||
agent_memory: List[Dict[str, str]] | None = None
|
||||
tool_call: ToolCall | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
start_time: float | None = None
|
||||
end_time: float | None = None
|
||||
step: int | None = None
|
||||
|
@ -302,25 +302,26 @@ class MultiStepAgent:
|
|||
}
|
||||
memory.append(thought_message)
|
||||
|
||||
if step_log.tool_call is not None:
|
||||
if step_log.tool_calls is not None:
|
||||
tool_call_message = {
|
||||
"role": MessageRole.ASSISTANT,
|
||||
"content": str(
|
||||
[
|
||||
{
|
||||
"id": step_log.tool_call.id,
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": step_log.tool_call.name,
|
||||
"arguments": step_log.tool_call.arguments,
|
||||
"name": tool_call.name,
|
||||
"arguments": tool_call.arguments,
|
||||
},
|
||||
}
|
||||
for tool_call in step_log.tool_calls
|
||||
]
|
||||
),
|
||||
}
|
||||
memory.append(tool_call_message)
|
||||
|
||||
if step_log.tool_call is None and step_log.error is not None:
|
||||
if step_log.tool_calls is None and step_log.error is not None:
|
||||
message_content = (
|
||||
"Error:\n"
|
||||
+ str(step_log.error)
|
||||
|
@ -330,7 +331,7 @@ class MultiStepAgent:
|
|||
"role": MessageRole.ASSISTANT,
|
||||
"content": message_content,
|
||||
}
|
||||
if step_log.tool_call is not None and (
|
||||
if step_log.tool_calls is not None and (
|
||||
step_log.error is not None or step_log.observations is not None
|
||||
):
|
||||
if step_log.error is not None:
|
||||
|
@ -343,7 +344,7 @@ class MultiStepAgent:
|
|||
message_content = f"Observation:\n{step_log.observations}"
|
||||
tool_response_message = {
|
||||
"role": MessageRole.TOOL_RESPONSE,
|
||||
"content": f"Call id: {(step_log.tool_call.id if getattr(step_log.tool_call, 'id') else 'call_0')}\n"
|
||||
"content": f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n"
|
||||
+ message_content,
|
||||
}
|
||||
memory.append(tool_response_message)
|
||||
|
@ -814,9 +815,9 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
f"Error in generating tool call with model:\n{e}"
|
||||
)
|
||||
|
||||
log_entry.tool_call = ToolCall(
|
||||
name=tool_name, arguments=tool_arguments, id=tool_call_id
|
||||
)
|
||||
log_entry.tool_calls = [
|
||||
ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
|
||||
]
|
||||
|
||||
# Execute
|
||||
self.logger.log(
|
||||
|
@ -977,11 +978,13 @@ class CodeAgent(MultiStepAgent):
|
|||
)
|
||||
raise AgentParsingError(error_msg)
|
||||
|
||||
log_entry.tool_call = ToolCall(
|
||||
log_entry.tool_calls = [
|
||||
ToolCall(
|
||||
name="python_interpreter",
|
||||
arguments=code_action,
|
||||
id=f"call_{len(self.logs)}",
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
self.logger.log(
|
||||
|
|
|
@ -43,9 +43,7 @@ class Monitor:
|
|||
def update_metrics(self, step_log):
|
||||
step_duration = step_log.duration
|
||||
self.step_durations.append(step_duration)
|
||||
console_outputs = (
|
||||
f"[Step {len(self.step_durations)-1}: Duration {step_duration:.2f} seconds"
|
||||
)
|
||||
console_outputs = f"[Step {len(self.step_durations) - 1}: Duration {step_duration:.2f} seconds"
|
||||
|
||||
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
|
||||
self.total_input_token_count += self.tracked_model.last_input_token_count
|
||||
|
|
|
@ -179,12 +179,12 @@ class Tool:
|
|||
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
||||
)
|
||||
for input_name, input_content in self.inputs.items():
|
||||
assert isinstance(
|
||||
input_content, dict
|
||||
), f"Input '{input_name}' should be a dictionary."
|
||||
assert (
|
||||
"type" in input_content and "description" in input_content
|
||||
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||
assert isinstance(input_content, dict), (
|
||||
f"Input '{input_name}' should be a dictionary."
|
||||
)
|
||||
assert "type" in input_content and "description" in input_content, (
|
||||
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||
)
|
||||
if input_content["type"] not in AUTHORIZED_TYPES:
|
||||
raise Exception(
|
||||
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
|
||||
|
@ -207,13 +207,13 @@ class Tool:
|
|||
json_schema = _convert_type_hints_to_json_schema(self.forward)
|
||||
for key, value in self.inputs.items():
|
||||
if "nullable" in value:
|
||||
assert (
|
||||
key in json_schema and "nullable" in json_schema[key]
|
||||
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||
assert key in json_schema and "nullable" in json_schema[key], (
|
||||
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||
)
|
||||
if key in json_schema and "nullable" in json_schema[key]:
|
||||
assert (
|
||||
"nullable" in value
|
||||
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
|
||||
assert "nullable" in value, (
|
||||
f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
|
||||
)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
||||
|
@ -272,7 +272,7 @@ class Tool:
|
|||
class {class_name}(Tool):
|
||||
name = "{self.name}"
|
||||
description = "{self.description}"
|
||||
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
|
||||
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
|
||||
output_type = "{self.output_type}"
|
||||
""").strip()
|
||||
import re
|
||||
|
@ -439,7 +439,9 @@ class Tool:
|
|||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
|
||||
others will be passed along to its init.
|
||||
"""
|
||||
assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
|
||||
assert trust_remote_code, (
|
||||
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
|
||||
)
|
||||
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
|
@ -620,13 +622,10 @@ class Tool:
|
|||
arg.save(temp_file.name)
|
||||
arg = temp_file.name
|
||||
if (
|
||||
isinstance(arg, str)
|
||||
and os.path.isfile(arg)
|
||||
) or (
|
||||
isinstance(arg, Path)
|
||||
and arg.exists()
|
||||
and arg.is_file()
|
||||
) or is_http_url_like(arg):
|
||||
(isinstance(arg, str) and os.path.isfile(arg))
|
||||
or (isinstance(arg, Path) and arg.exists() and arg.is_file())
|
||||
or is_http_url_like(arg)
|
||||
):
|
||||
arg = handle_file(arg)
|
||||
return arg
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
|||
raise ValueError(
|
||||
f"The JSON blob you used is invalid due to the following error: {e}.\n"
|
||||
f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
|
||||
f"'{json_blob[place-4:place+5]}'."
|
||||
f"'{json_blob[place - 4 : place + 5]}'."
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
||||
|
|
Loading…
Reference in New Issue