Replace max_iterations with max_steps for consistency
This commit is contained in:
parent 07015d12fe
commit e9119c9df5
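The rename touches three user-facing names: the constructor keyword `max_iterations` becomes `max_steps`, the `ActionStep.iteration` field becomes `ActionStep.step`, and `AgentMaxIterationsError` becomes `AgentMaxStepsError`. As orientation before the diff, here is a minimal before/after sketch of the constructor change, using only classes and values that appear in the hunks below (the empty toolbox is just for brevity):

```python
from smolagents import CodeAgent, HfApiModel

model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")

# Before this commit:
# agent = CodeAgent(tools=[], model=model, max_iterations=4, verbose=True)

# After this commit:
agent = CodeAgent(tools=[], model=model, max_steps=4, verbose=True)
agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
```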
@@ -120,7 +120,7 @@ Now that we have all the tools `search` and `visit_webpage`, we can use them to

 Which configuration to choose for this agent?
 - Web browsing is a single-timeline task that does not require parallel tool calls, so JSON tool calling works well for that. We thus choose a `JsonAgent`.
-- Also, since sometimes web search requires exploring many pages before finding the correct answer, we prefer to increase the number of `max_iterations` to 10.
+- Also, since sometimes web search requires exploring many pages before finding the correct answer, we prefer to increase the number of `max_steps` to 10.

 ```py
 from smolagents import (
@@ -137,7 +137,7 @@ model = HfApiModel(model_id)
 web_agent = ToolCallingAgent(
     tools=[DuckDuckGoSearchTool(), visit_webpage],
     model=model,
-    max_iterations=10,
+    max_steps=10,
 )
 ```

@@ -137,7 +137,7 @@ _Note:_ The Inference API hosts models based on various criteria, and deployed m
 from smolagents import HfApiModel, CodeAgent

 agent = CodeAgent(
-    tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
+    tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True
 )
 ```
@@ -125,6 +125,7 @@ print(model(messages))
 ### LiteLLMModel

 The `LiteLLMModel` leverages [LiteLLM](https://www.litellm.ai/) to support 100+ LLMs from various providers.
+You can pass kwargs upon model initialization that will then be used whenever using the model, for instance below we pass `temperature`.

 ```python
 from smolagents import LiteLLMModel
@@ -135,8 +136,8 @@ messages = [
     {"role": "user", "content": "No need to help, take it easy."},
 ]

-model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest")
-print(model(messages))
+model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2)
+print(model(messages, max_tokens=10))
 ```

 [[autodoc]] LiteLLMModel
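The added docs line is the behavioral point here: kwargs given at `LiteLLMModel(...)` construction are remembered and applied on every call, while kwargs given at call time apply to that single completion. A small sketch restating the example above:

```python
from smolagents import LiteLLMModel

messages = [
    {"role": "user", "content": "No need to help, take it easy."},
]

# temperature=0.2 is stored at init and reused for every call below.
model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2)
print(model(messages))                 # init kwargs apply
print(model(messages, max_tokens=10))  # plus a per-call token cap
```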
@@ -194,7 +194,7 @@ If after trying the above, you still want to change the system prompt, your new
 Then you can change the system prompt as follows:

 ```py
-from smolagents.prompts import CODE_SYSTEM_PROMPT, HfApiModel
+from smolagents.prompts import CODE_SYSTEM_PROMPT

 modified_system_prompt = CODE_SYSTEM_PROMPT + "\nHere you go!" # Change the system prompt here
@@ -233,7 +233,7 @@
     "    agent = ToolCallingAgent(\n",
     "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
     "        model=HfApiModel(model_id),\n",
-    "        max_iterations=10,\n",
+    "        max_steps=10,\n",
     "    )\n",
     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
@@ -243,7 +243,7 @@
     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
     "        model=HfApiModel(model_id),\n",
     "        additional_authorized_imports=[\"numpy\"],\n",
-    "        max_iterations=10,\n",
+    "        max_steps=10,\n",
     "    )\n",
     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)"
@@ -272,7 +272,7 @@
     "    agent = ToolCallingAgent(\n",
     "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
     "        model=LiteLLMModel(model_id),\n",
-    "        max_iterations=10,\n",
+    "        max_steps=10,\n",
     "    )\n",
     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
@@ -282,7 +282,7 @@
     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
     "        model=LiteLLMModel(model_id),\n",
     "        additional_authorized_imports=[\"numpy\"],\n",
-    "        max_iterations=10,\n",
+    "        max_steps=10,\n",
     "    )\n",
     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)"
@@ -60,7 +60,7 @@ from smolagents import HfApiModel, CodeAgent

 retriever_tool = RetrieverTool(docs_processed)
 agent = CodeAgent(
-    tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
+    tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True
 )

 agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
@@ -32,7 +32,7 @@ from .utils import (
     AgentParsingError,
     AgentExecutionError,
     AgentGenerationError,
-    AgentMaxIterationsError,
+    AgentMaxStepsError,
 )
 from .types import AgentAudio, AgentImage, handle_agent_output_types
 from .default_tools import FinalAnswerTool
@@ -78,7 +78,7 @@ class ActionStep(AgentStep):
     tool_call: ToolCall | None = None
     start_time: float | None = None
     end_time: float | None = None
-    iteration: int | None = None
+    step: int | None = None
     error: AgentError | None = None
     duration: float | None = None
     llm_output: str | None = None
@@ -163,7 +163,7 @@ class MultiStepAgent:
         model: Callable[[List[Dict[str, str]]], str],
         system_prompt: Optional[str] = None,
         tool_description_template: Optional[str] = None,
-        max_iterations: int = 6,
+        max_steps: int = 6,
         tool_parser: Optional[Callable] = None,
         add_base_tools: bool = False,
         verbose: bool = False,
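Worth noting in this hunk: the default budget stays at 6 (`max_steps: int = 6`), so only code that set `max_iterations` explicitly needs the rename. A sketch of the two cases, with the model id taken from the docs examples above:

```python
from smolagents import CodeAgent, HfApiModel

model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")

agent = CodeAgent(tools=[], model=model)                  # max_steps defaults to 6
browsing_agent = CodeAgent(tools=[], model=model, max_steps=10)  # explicit budget
```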
@@ -184,7 +184,7 @@ class MultiStepAgent:
             if tool_description_template
             else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
         )
-        self.max_iterations = max_iterations
+        self.max_steps = max_steps
         self.tool_parser = tool_parser
         self.grammar = grammar
         self.planning_interval = planning_interval
@@ -500,20 +500,18 @@ You have been provided with these additional arguments, that you can access usin
         Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
         """
         final_answer = None
-        iteration = 0
-        while final_answer is None and iteration < self.max_iterations:
+        step = 0
+        while final_answer is None and step < self.max_steps:
             step_start_time = time.time()
-            step_log = ActionStep(iteration=iteration, start_time=step_start_time)
+            step_log = ActionStep(step=step, start_time=step_start_time)
             try:
                 if (
                     self.planning_interval is not None
-                    and iteration % self.planning_interval == 0
+                    and step % self.planning_interval == 0
                 ):
-                    self.planning_step(
-                        task, is_first_step=(iteration == 0), iteration=iteration
-                    )
+                    self.planning_step(task, is_first_step=(step == 0), step=step)
                 console.print(
-                    Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX)
+                    Rule(f"[bold]Step {step}", characters="━", style=YELLOW_HEX)
                 )

                 # Run one step!
@@ -526,12 +524,12 @@ You have been provided with these additional arguments, that you can access usin
             self.logs.append(step_log)
             for callback in self.step_callbacks:
                 callback(step_log)
-            iteration += 1
+            step += 1
             yield step_log

-        if final_answer is None and iteration == self.max_iterations:
-            error_message = "Reached max iterations."
-            final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
+        if final_answer is None and step == self.max_steps:
+            error_message = "Reached max steps."
+            final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
             self.logs.append(final_step_log)
             final_answer = self.provide_final_answer(task)
             console.print(Text(f"Final answer: {final_answer}"))
@@ -549,20 +547,18 @@ You have been provided with these additional arguments, that you can access usin
         Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
         """
         final_answer = None
-        iteration = 0
-        while final_answer is None and iteration < self.max_iterations:
+        step = 0
+        while final_answer is None and step < self.max_steps:
             step_start_time = time.time()
-            step_log = ActionStep(iteration=iteration, start_time=step_start_time)
+            step_log = ActionStep(step=step, start_time=step_start_time)
             try:
                 if (
                     self.planning_interval is not None
-                    and iteration % self.planning_interval == 0
+                    and step % self.planning_interval == 0
                 ):
-                    self.planning_step(
-                        task, is_first_step=(iteration == 0), iteration=iteration
-                    )
+                    self.planning_step(task, is_first_step=(step == 0), step=step)
                 console.print(
-                    Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX)
+                    Rule(f"[bold]Step {step}", characters="━", style=YELLOW_HEX)
                 )

                 # Run one step!
@@ -577,11 +573,11 @@ You have been provided with these additional arguments, that you can access usin
             self.logs.append(step_log)
             for callback in self.step_callbacks:
                 callback(step_log)
-            iteration += 1
+            step += 1

-        if final_answer is None and iteration == self.max_iterations:
-            error_message = "Reached max iterations."
-            final_step_log = ActionStep(error=AgentMaxIterationsError(error_message))
+        if final_answer is None and step == self.max_steps:
+            error_message = "Reached max steps."
+            final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
             self.logs.append(final_step_log)
             final_answer = self.provide_final_answer(task)
             console.print(Text(f"Final answer: {final_answer}"))
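Across both run loops the exhaustion path is unchanged in shape: the agent appends a final `ActionStep` carrying the renamed `AgentMaxStepsError`, then asks the model for a best-effort answer via `provide_final_answer(task)` rather than raising. A hedged sketch of how calling code can detect that outcome, following the pattern the tests below assert on:

```python
from smolagents import CodeAgent, HfApiModel
from smolagents.agents import AgentMaxStepsError

model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")
agent = CodeAgent(tools=[], model=model, max_steps=2)
result = agent.run("Some task")  # does not raise on exhaustion

# When the budget ran out, the last log entry is the error step;
# getattr guards the case where the run finished normally.
hit_limit = isinstance(getattr(agent.logs[-1], "error", None), AgentMaxStepsError)
```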
@@ -592,14 +588,14 @@ You have been provided with these additional arguments, that you can access usin

         return handle_agent_output_types(final_answer)

-    def planning_step(self, task, is_first_step: bool, iteration: int):
+    def planning_step(self, task, is_first_step: bool, step: int):
         """
         Used periodically by the agent to plan the next steps to reach the objective.

         Args:
             task (`str`): The task to perform
             is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
-            iteration (`int`): The number of the current step, used as an indication for the LLM.
+            step (`int`): The number of the current step, used as an indication for the LLM.
         """
         if is_first_step:
             message_prompt_facts = {
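`planning_step` itself only renames its keyword (`iteration` to `step`) to match the loop counter; users trigger it indirectly through `planning_interval`, which fires on step 0 and every interval after that (`step % planning_interval == 0` in the loops above). A sketch, assuming `planning_interval` is exposed on the agent constructor as the `__init__` assignment earlier suggests:

```python
from smolagents import CodeAgent, HfApiModel

model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")

# With planning_interval=3 and max_steps=9, planning runs on steps 0, 3 and 6.
agent = CodeAgent(tools=[], model=model, planning_interval=3, max_steps=9)
agent.run("Some multi-step task")
```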
@@ -687,7 +683,7 @@ Now begin!""",
                     show_agents_descriptions(self.managed_agents)
                 ),
                 facts_update=facts_update,
-                remaining_steps=(self.max_iterations - iteration),
+                remaining_steps=(self.max_steps - step),
             ),
         }
         plan_update = self.model(
@@ -281,12 +281,12 @@ class HfApiModel(Model):

 class TransformersModel(Model):
     """This engine initializes a model and tokenizer from the given `model_id`.

     Parameters:
         model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
         device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
             The device to load the model on (`"cpu"` or `"cuda"`).
     """

     def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
@@ -310,7 +310,9 @@ class TransformersModel(Model):
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
             )
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
+                self.device
+            )

     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
@@ -424,7 +426,7 @@ class LiteLLMModel(Model):
         model_id="anthropic/claude-3-5-sonnet-20240620",
         api_base=None,
         api_key=None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.model_id = model_id
@@ -575,7 +575,6 @@ class Tool:
         from gradio_client import Client, handle_file

         class SpaceToolWrapper(Tool):
-
             skip_forward_signature_validation = True

             def __init__(
@@ -907,12 +906,17 @@ class ToolCollection:
     ```
     """

-    def __init__(self, collection_slug: str, token: Optional[str] = None, trust_remote_code=False):
+    def __init__(
+        self, collection_slug: str, token: Optional[str] = None, trust_remote_code=False
+    ):
         self._collection = get_collection(collection_slug, token=token)
         self._hub_repo_ids = {
             item.item_id for item in self._collection.items if item.item_type == "space"
         }
-        self.tools = {Tool.from_hub(repo_id,token,trust_remote_code) for repo_id in self._hub_repo_ids}
+        self.tools = {
+            Tool.from_hub(repo_id, token, trust_remote_code)
+            for repo_id in self._hub_repo_ids
+        }


 def tool(tool_function: Callable) -> Tool:
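For orientation, the reformatted constructor is the one used when loading a whole Hub collection of Space tools; a usage sketch (the collection slug is a placeholder, and the top-level `ToolCollection` export is assumed):

```python
from smolagents import ToolCollection

# Placeholder slug; trust_remote_code opts in to executing the tools' code.
collection = ToolCollection(
    "huggingface-tools/example-collection-slug", trust_remote_code=True
)
print(len(collection.tools))  # one Tool per Space item in the collection
```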
@@ -31,7 +31,7 @@ def is_pygments_available():
     return _pygments_available


-console = Console()
+console = Console(width=200)

 BASE_BUILTIN_MODULES = [
     "collections",
@@ -69,7 +69,7 @@ class AgentExecutionError(AgentError):
     pass


-class AgentMaxIterationsError(AgentError):
+class AgentMaxStepsError(AgentError):
     """Exception raised for errors in execution in the agent"""

     pass
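Downstream code only needs a one-line rename; the exception class lives in the utils module and, per the import hunks above, is re-exported through `smolagents.agents`:

```python
# Old:
# from smolagents.agents import AgentMaxIterationsError
# New (same class, renamed):
from smolagents.agents import AgentMaxStepsError

# Note: run() does not raise this; the agent records it on the final
# ActionStep when the step budget is exhausted (see the run-loop hunks).
```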
@@ -22,7 +22,7 @@ from pathlib import Path

 from smolagents.types import AgentText, AgentImage
 from smolagents.agents import (
-    AgentMaxIterationsError,
+    AgentMaxStepsError,
     ManagedAgent,
     CodeAgent,
     ToolCallingAgent,
@@ -279,15 +279,15 @@ class AgentTests(unittest.TestCase):
     def test_setup_agent_with_empty_toolbox(self):
         ToolCallingAgent(model=FakeToolCallModel(), tools=[])

-    def test_fails_max_iterations(self):
+    def test_fails_max_steps(self):
         agent = CodeAgent(
             tools=[PythonInterpreterTool()],
             model=fake_code_model_no_return,  # use this callable because it never ends
-            max_iterations=5,
+            max_steps=5,
         )
         agent.run("What is 2 multiplied by 3.6452?")
         assert len(agent.logs) == 8
-        assert type(agent.logs[-1].error) is AgentMaxIterationsError
+        assert type(agent.logs[-1].error) is AgentMaxStepsError

     def test_init_agent_with_different_toolsets(self):
         toolset_1 = []
@@ -325,7 +325,7 @@ class AgentTests(unittest.TestCase):
         agent = CodeAgent(
             tools=[],
             model=fake_code_functiondef,
-            max_iterations=2,
+            max_steps=2,
             additional_authorized_imports=["numpy"],
         )
         res = agent.run("ok")
@@ -41,7 +41,7 @@ final_answer('This is the final answer.')
         agent = CodeAgent(
             tools=[],
             model=FakeLLMModel(),
-            max_iterations=1,
+            max_steps=1,
         )

         agent.run("Fake task")
@@ -61,7 +61,7 @@ final_answer('This is the final answer.')
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLMModel(),
-            max_iterations=1,
+            max_steps=1,
         )

         agent.run("Fake task")
@@ -69,7 +69,7 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_input_token_count, 10)
         self.assertEqual(agent.monitor.total_output_token_count, 20)

-    def test_code_agent_metrics_max_iterations(self):
+    def test_code_agent_metrics_max_steps(self):
         class FakeLLMModel:
             def __init__(self):
                 self.last_input_token_count = 10
@@ -81,7 +81,7 @@ final_answer('This is the final answer.')
         agent = CodeAgent(
             tools=[],
             model=FakeLLMModel(),
-            max_iterations=1,
+            max_steps=1,
         )

         agent.run("Fake task")
@@ -103,7 +103,7 @@ final_answer('This is the final answer.')
         agent = CodeAgent(
             tools=[],
             model=FakeLLMModel(),
-            max_iterations=1,
+            max_steps=1,
         )
         agent.run("Fake task")

@@ -123,7 +123,7 @@ final_answer('This is the final answer.')
         agent = CodeAgent(
             tools=[],
             model=dummy_model,
-            max_iterations=1,
+            max_steps=1,
         )

         # Use stream_to_gradio to capture the output
@@ -145,7 +145,7 @@ final_answer('This is the final answer.')
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLM(),
-            max_iterations=1,
+            max_steps=1,
         )

         # Use stream_to_gradio to capture the output
@@ -172,7 +172,7 @@ final_answer('This is the final answer.')
         agent = CodeAgent(
             tools=[],
             model=dummy_model,
-            max_iterations=1,
+            max_steps=1,
         )

         # Use stream_to_gradio to capture the output