Replace max_iteration with max_steps for consistency
This commit is contained in:
		
							parent
							
								
									07015d12fe
								
							
						
					
					
						commit
						e9119c9df5
					
				|  | @ -120,7 +120,7 @@ Now that we have all the tools `search` and `visit_webpage`, we can use them to | ||||||
| 
 | 
 | ||||||
| Which configuration to choose for this agent? | Which configuration to choose for this agent? | ||||||
| - Web browsing is a single-timeline task that does not require parallel tool calls, so JSON tool calling works well for that. We thus choose a `JsonAgent`. | - Web browsing is a single-timeline task that does not require parallel tool calls, so JSON tool calling works well for that. We thus choose a `JsonAgent`. | ||||||
| - Also, since sometimes web search requires exploring many pages before finding the correct answer, we prefer to increase the number of `max_iterations` to 10. | - Also, since sometimes web search requires exploring many pages before finding the correct answer, we prefer to increase the number of `max_steps` to 10. | ||||||
| 
 | 
 | ||||||
| ```py | ```py | ||||||
| from smolagents import ( | from smolagents import ( | ||||||
|  | @ -137,7 +137,7 @@ model = HfApiModel(model_id) | ||||||
| web_agent = ToolCallingAgent( | web_agent = ToolCallingAgent( | ||||||
|     tools=[DuckDuckGoSearchTool(), visit_webpage], |     tools=[DuckDuckGoSearchTool(), visit_webpage], | ||||||
|     model=model, |     model=model, | ||||||
|     max_iterations=10, |     max_steps=10, | ||||||
| ) | ) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -137,7 +137,7 @@ _Note:_ The Inference API hosts models based on various criteria, and deployed m | ||||||
| from smolagents import HfApiModel, CodeAgent | from smolagents import HfApiModel, CodeAgent | ||||||
| 
 | 
 | ||||||
| agent = CodeAgent( | agent = CodeAgent( | ||||||
|     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True |     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True | ||||||
| ) | ) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -125,6 +125,7 @@ print(model(messages)) | ||||||
| ### LiteLLMModel | ### LiteLLMModel | ||||||
| 
 | 
 | ||||||
| The `LiteLLMModel` leverages [LiteLLM](https://www.litellm.ai/) to support 100+ LLMs from various providers. | The `LiteLLMModel` leverages [LiteLLM](https://www.litellm.ai/) to support 100+ LLMs from various providers. | ||||||
|  | You can pass kwargs upon model initialization that will then be used whenever using the model, for instance below we pass `temperature`. | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| from smolagents import LiteLLMModel | from smolagents import LiteLLMModel | ||||||
|  | @ -135,8 +136,8 @@ messages = [ | ||||||
|   {"role": "user", "content": "No need to help, take it easy."}, |   {"role": "user", "content": "No need to help, take it easy."}, | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest") | model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2) | ||||||
| print(model(messages)) | print(model(messages, max_tokens=10)) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| [[autodoc]] LiteLLMModel | [[autodoc]] LiteLLMModel | ||||||
|  | @ -194,7 +194,7 @@ If after trying the above, you still want to change the system prompt, your new | ||||||
| Then you can change the system prompt as follows: | Then you can change the system prompt as follows: | ||||||
| 
 | 
 | ||||||
| ```py | ```py | ||||||
| from smolagents.prompts import CODE_SYSTEM_PROMPT, HfApiModel | from smolagents.prompts import CODE_SYSTEM_PROMPT | ||||||
| 
 | 
 | ||||||
| modified_system_prompt = CODE_SYSTEM_PROMPT + "\nHere you go!" # Change the system prompt here | modified_system_prompt = CODE_SYSTEM_PROMPT + "\nHere you go!" # Change the system prompt here | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -233,7 +233,7 @@ | ||||||
|     "    agent = ToolCallingAgent(\n", |     "    agent = ToolCallingAgent(\n", | ||||||
|     "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n", |     "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n", | ||||||
|     "        model=HfApiModel(model_id),\n", |     "        model=HfApiModel(model_id),\n", | ||||||
|     "        max_iterations=10,\n", |     "        max_steps=10,\n", | ||||||
|     "    )\n", |     "    )\n", | ||||||
|     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", |     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", | ||||||
|     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", |     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", | ||||||
|  | @ -243,7 +243,7 @@ | ||||||
|     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", |     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", | ||||||
|     "        model=HfApiModel(model_id),\n", |     "        model=HfApiModel(model_id),\n", | ||||||
|     "        additional_authorized_imports=[\"numpy\"],\n", |     "        additional_authorized_imports=[\"numpy\"],\n", | ||||||
|     "        max_iterations=10,\n", |     "        max_steps=10,\n", | ||||||
|     "    )\n", |     "    )\n", | ||||||
|     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", |     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", | ||||||
|     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)" |     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)" | ||||||
|  | @ -272,7 +272,7 @@ | ||||||
|     "    agent = ToolCallingAgent(\n", |     "    agent = ToolCallingAgent(\n", | ||||||
|     "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n", |     "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n", | ||||||
|     "        model=LiteLLMModel(model_id),\n", |     "        model=LiteLLMModel(model_id),\n", | ||||||
|     "        max_iterations=10,\n", |     "        max_steps=10,\n", | ||||||
|     "    )\n", |     "    )\n", | ||||||
|     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", |     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", | ||||||
|     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", |     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", | ||||||
|  | @ -282,7 +282,7 @@ | ||||||
|     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", |     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", | ||||||
|     "        model=LiteLLMModel(model_id),\n", |     "        model=LiteLLMModel(model_id),\n", | ||||||
|     "        additional_authorized_imports=[\"numpy\"],\n", |     "        additional_authorized_imports=[\"numpy\"],\n", | ||||||
|     "        max_iterations=10,\n", |     "        max_steps=10,\n", | ||||||
|     "    )\n", |     "    )\n", | ||||||
|     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", |     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", | ||||||
|     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)" |     "    answer_questions(eval_ds, file_name, agent, model_id, action_type)" | ||||||
|  |  | ||||||
|  | @ -60,7 +60,7 @@ from smolagents import HfApiModel, CodeAgent | ||||||
| 
 | 
 | ||||||
| retriever_tool = RetrieverTool(docs_processed) | retriever_tool = RetrieverTool(docs_processed) | ||||||
| agent = CodeAgent( | agent = CodeAgent( | ||||||
|     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True |     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?") | agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?") | ||||||
|  |  | ||||||
|  | @ -32,7 +32,7 @@ from .utils import ( | ||||||
|     AgentParsingError, |     AgentParsingError, | ||||||
|     AgentExecutionError, |     AgentExecutionError, | ||||||
|     AgentGenerationError, |     AgentGenerationError, | ||||||
|     AgentMaxIterationsError, |     AgentMaxStepsError, | ||||||
| ) | ) | ||||||
| from .types import AgentAudio, AgentImage, handle_agent_output_types | from .types import AgentAudio, AgentImage, handle_agent_output_types | ||||||
| from .default_tools import FinalAnswerTool | from .default_tools import FinalAnswerTool | ||||||
|  | @ -78,7 +78,7 @@ class ActionStep(AgentStep): | ||||||
|     tool_call: ToolCall | None = None |     tool_call: ToolCall | None = None | ||||||
|     start_time: float | None = None |     start_time: float | None = None | ||||||
|     end_time: float | None = None |     end_time: float | None = None | ||||||
|     iteration: int | None = None |     step: int | None = None | ||||||
|     error: AgentError | None = None |     error: AgentError | None = None | ||||||
|     duration: float | None = None |     duration: float | None = None | ||||||
|     llm_output: str | None = None |     llm_output: str | None = None | ||||||
|  | @ -163,7 +163,7 @@ class MultiStepAgent: | ||||||
|         model: Callable[[List[Dict[str, str]]], str], |         model: Callable[[List[Dict[str, str]]], str], | ||||||
|         system_prompt: Optional[str] = None, |         system_prompt: Optional[str] = None, | ||||||
|         tool_description_template: Optional[str] = None, |         tool_description_template: Optional[str] = None, | ||||||
|         max_iterations: int = 6, |         max_steps: int = 6, | ||||||
|         tool_parser: Optional[Callable] = None, |         tool_parser: Optional[Callable] = None, | ||||||
|         add_base_tools: bool = False, |         add_base_tools: bool = False, | ||||||
|         verbose: bool = False, |         verbose: bool = False, | ||||||
|  | @ -184,7 +184,7 @@ class MultiStepAgent: | ||||||
|             if tool_description_template |             if tool_description_template | ||||||
|             else DEFAULT_TOOL_DESCRIPTION_TEMPLATE |             else DEFAULT_TOOL_DESCRIPTION_TEMPLATE | ||||||
|         ) |         ) | ||||||
|         self.max_iterations = max_iterations |         self.max_steps = max_steps | ||||||
|         self.tool_parser = tool_parser |         self.tool_parser = tool_parser | ||||||
|         self.grammar = grammar |         self.grammar = grammar | ||||||
|         self.planning_interval = planning_interval |         self.planning_interval = planning_interval | ||||||
|  | @ -500,20 +500,18 @@ You have been provided with these additional arguments, that you can access usin | ||||||
|         Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method. |         Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method. | ||||||
|         """ |         """ | ||||||
|         final_answer = None |         final_answer = None | ||||||
|         iteration = 0 |         step = 0 | ||||||
|         while final_answer is None and iteration < self.max_iterations: |         while final_answer is None and step < self.max_steps: | ||||||
|             step_start_time = time.time() |             step_start_time = time.time() | ||||||
|             step_log = ActionStep(iteration=iteration, start_time=step_start_time) |             step_log = ActionStep(step=step, start_time=step_start_time) | ||||||
|             try: |             try: | ||||||
|                 if ( |                 if ( | ||||||
|                     self.planning_interval is not None |                     self.planning_interval is not None | ||||||
|                     and iteration % self.planning_interval == 0 |                     and step % self.planning_interval == 0 | ||||||
|                 ): |                 ): | ||||||
|                     self.planning_step( |                     self.planning_step(task, is_first_step=(step == 0), step=step) | ||||||
|                         task, is_first_step=(iteration == 0), iteration=iteration |  | ||||||
|                     ) |  | ||||||
|                 console.print( |                 console.print( | ||||||
|                     Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX) |                     Rule(f"[bold]Step {step}", characters="━", style=YELLOW_HEX) | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|                 # Run one step! |                 # Run one step! | ||||||
|  | @ -526,12 +524,12 @@ You have been provided with these additional arguments, that you can access usin | ||||||
|                 self.logs.append(step_log) |                 self.logs.append(step_log) | ||||||
|                 for callback in self.step_callbacks: |                 for callback in self.step_callbacks: | ||||||
|                     callback(step_log) |                     callback(step_log) | ||||||
|                 iteration += 1 |                 step += 1 | ||||||
|                 yield step_log |                 yield step_log | ||||||
| 
 | 
 | ||||||
|         if final_answer is None and iteration == self.max_iterations: |         if final_answer is None and step == self.max_steps: | ||||||
|             error_message = "Reached max iterations." |             error_message = "Reached max steps." | ||||||
|             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) |             final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             console.print(Text(f"Final answer: {final_answer}")) |             console.print(Text(f"Final answer: {final_answer}")) | ||||||
|  | @ -549,20 +547,18 @@ You have been provided with these additional arguments, that you can access usin | ||||||
|         Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method. |         Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method. | ||||||
|         """ |         """ | ||||||
|         final_answer = None |         final_answer = None | ||||||
|         iteration = 0 |         step = 0 | ||||||
|         while final_answer is None and iteration < self.max_iterations: |         while final_answer is None and step < self.max_steps: | ||||||
|             step_start_time = time.time() |             step_start_time = time.time() | ||||||
|             step_log = ActionStep(iteration=iteration, start_time=step_start_time) |             step_log = ActionStep(step=step, start_time=step_start_time) | ||||||
|             try: |             try: | ||||||
|                 if ( |                 if ( | ||||||
|                     self.planning_interval is not None |                     self.planning_interval is not None | ||||||
|                     and iteration % self.planning_interval == 0 |                     and step % self.planning_interval == 0 | ||||||
|                 ): |                 ): | ||||||
|                     self.planning_step( |                     self.planning_step(task, is_first_step=(step == 0), step=step) | ||||||
|                         task, is_first_step=(iteration == 0), iteration=iteration |  | ||||||
|                     ) |  | ||||||
|                 console.print( |                 console.print( | ||||||
|                     Rule(f"[bold]Step {iteration}", characters="━", style=YELLOW_HEX) |                     Rule(f"[bold]Step {step}", characters="━", style=YELLOW_HEX) | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|                 # Run one step! |                 # Run one step! | ||||||
|  | @ -577,11 +573,11 @@ You have been provided with these additional arguments, that you can access usin | ||||||
|                 self.logs.append(step_log) |                 self.logs.append(step_log) | ||||||
|                 for callback in self.step_callbacks: |                 for callback in self.step_callbacks: | ||||||
|                     callback(step_log) |                     callback(step_log) | ||||||
|                 iteration += 1 |                 step += 1 | ||||||
| 
 | 
 | ||||||
|         if final_answer is None and iteration == self.max_iterations: |         if final_answer is None and step == self.max_steps: | ||||||
|             error_message = "Reached max iterations." |             error_message = "Reached max steps." | ||||||
|             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) |             final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             console.print(Text(f"Final answer: {final_answer}")) |             console.print(Text(f"Final answer: {final_answer}")) | ||||||
|  | @ -592,14 +588,14 @@ You have been provided with these additional arguments, that you can access usin | ||||||
| 
 | 
 | ||||||
|         return handle_agent_output_types(final_answer) |         return handle_agent_output_types(final_answer) | ||||||
| 
 | 
 | ||||||
|     def planning_step(self, task, is_first_step: bool, iteration: int): |     def planning_step(self, task, is_first_step: bool, step: int): | ||||||
|         """ |         """ | ||||||
|         Used periodically by the agent to plan the next steps to reach the objective. |         Used periodically by the agent to plan the next steps to reach the objective. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             task (`str`): The task to perform |             task (`str`): The task to perform | ||||||
|             is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan. |             is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan. | ||||||
|             iteration (`int`): The number of the current step, used as an indication for the LLM. |             step (`int`): The number of the current step, used as an indication for the LLM. | ||||||
|         """ |         """ | ||||||
|         if is_first_step: |         if is_first_step: | ||||||
|             message_prompt_facts = { |             message_prompt_facts = { | ||||||
|  | @ -687,7 +683,7 @@ Now begin!""", | ||||||
|                         show_agents_descriptions(self.managed_agents) |                         show_agents_descriptions(self.managed_agents) | ||||||
|                     ), |                     ), | ||||||
|                     facts_update=facts_update, |                     facts_update=facts_update, | ||||||
|                     remaining_steps=(self.max_iterations - iteration), |                     remaining_steps=(self.max_steps - step), | ||||||
|                 ), |                 ), | ||||||
|             } |             } | ||||||
|             plan_update = self.model( |             plan_update = self.model( | ||||||
|  |  | ||||||
|  | @ -310,7 +310,9 @@ class TransformersModel(Model): | ||||||
|                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." |                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." | ||||||
|             ) |             ) | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device) |             self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to( | ||||||
|  |                 self.device | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: |     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: | ||||||
|         class StopOnStrings(StoppingCriteria): |         class StopOnStrings(StoppingCriteria): | ||||||
|  | @ -424,7 +426,7 @@ class LiteLLMModel(Model): | ||||||
|         model_id="anthropic/claude-3-5-sonnet-20240620", |         model_id="anthropic/claude-3-5-sonnet-20240620", | ||||||
|         api_base=None, |         api_base=None, | ||||||
|         api_key=None, |         api_key=None, | ||||||
|         **kwargs |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.model_id = model_id |         self.model_id = model_id | ||||||
|  |  | ||||||
|  | @ -575,7 +575,6 @@ class Tool: | ||||||
|         from gradio_client import Client, handle_file |         from gradio_client import Client, handle_file | ||||||
| 
 | 
 | ||||||
|         class SpaceToolWrapper(Tool): |         class SpaceToolWrapper(Tool): | ||||||
|              |  | ||||||
|             skip_forward_signature_validation = True |             skip_forward_signature_validation = True | ||||||
| 
 | 
 | ||||||
|             def __init__( |             def __init__( | ||||||
|  | @ -907,12 +906,17 @@ class ToolCollection: | ||||||
|     ``` |     ``` | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, collection_slug: str, token: Optional[str] = None, trust_remote_code=False): |     def __init__( | ||||||
|  |         self, collection_slug: str, token: Optional[str] = None, trust_remote_code=False | ||||||
|  |     ): | ||||||
|         self._collection = get_collection(collection_slug, token=token) |         self._collection = get_collection(collection_slug, token=token) | ||||||
|         self._hub_repo_ids = { |         self._hub_repo_ids = { | ||||||
|             item.item_id for item in self._collection.items if item.item_type == "space" |             item.item_id for item in self._collection.items if item.item_type == "space" | ||||||
|         } |         } | ||||||
|         self.tools = {Tool.from_hub(repo_id,token,trust_remote_code) for repo_id in self._hub_repo_ids} |         self.tools = { | ||||||
|  |             Tool.from_hub(repo_id, token, trust_remote_code) | ||||||
|  |             for repo_id in self._hub_repo_ids | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def tool(tool_function: Callable) -> Tool: | def tool(tool_function: Callable) -> Tool: | ||||||
|  |  | ||||||
|  | @ -31,7 +31,7 @@ def is_pygments_available(): | ||||||
|     return _pygments_available |     return _pygments_available | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| console = Console() | console = Console(width=200) | ||||||
| 
 | 
 | ||||||
| BASE_BUILTIN_MODULES = [ | BASE_BUILTIN_MODULES = [ | ||||||
|     "collections", |     "collections", | ||||||
|  | @ -69,7 +69,7 @@ class AgentExecutionError(AgentError): | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentMaxIterationsError(AgentError): | class AgentMaxStepsError(AgentError): | ||||||
|     """Exception raised for errors in execution in the agent""" |     """Exception raised for errors in execution in the agent""" | ||||||
| 
 | 
 | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  | @ -22,7 +22,7 @@ from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from smolagents.types import AgentText, AgentImage | from smolagents.types import AgentText, AgentImage | ||||||
| from smolagents.agents import ( | from smolagents.agents import ( | ||||||
|     AgentMaxIterationsError, |     AgentMaxStepsError, | ||||||
|     ManagedAgent, |     ManagedAgent, | ||||||
|     CodeAgent, |     CodeAgent, | ||||||
|     ToolCallingAgent, |     ToolCallingAgent, | ||||||
|  | @ -279,15 +279,15 @@ class AgentTests(unittest.TestCase): | ||||||
|     def test_setup_agent_with_empty_toolbox(self): |     def test_setup_agent_with_empty_toolbox(self): | ||||||
|         ToolCallingAgent(model=FakeToolCallModel(), tools=[]) |         ToolCallingAgent(model=FakeToolCallModel(), tools=[]) | ||||||
| 
 | 
 | ||||||
|     def test_fails_max_iterations(self): |     def test_fails_max_steps(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[PythonInterpreterTool()], |             tools=[PythonInterpreterTool()], | ||||||
|             model=fake_code_model_no_return,  # use this callable because it never ends |             model=fake_code_model_no_return,  # use this callable because it never ends | ||||||
|             max_iterations=5, |             max_steps=5, | ||||||
|         ) |         ) | ||||||
|         agent.run("What is 2 multiplied by 3.6452?") |         agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert len(agent.logs) == 8 |         assert len(agent.logs) == 8 | ||||||
|         assert type(agent.logs[-1].error) is AgentMaxIterationsError |         assert type(agent.logs[-1].error) is AgentMaxStepsError | ||||||
| 
 | 
 | ||||||
|     def test_init_agent_with_different_toolsets(self): |     def test_init_agent_with_different_toolsets(self): | ||||||
|         toolset_1 = [] |         toolset_1 = [] | ||||||
|  | @ -325,7 +325,7 @@ class AgentTests(unittest.TestCase): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=fake_code_functiondef, |             model=fake_code_functiondef, | ||||||
|             max_iterations=2, |             max_steps=2, | ||||||
|             additional_authorized_imports=["numpy"], |             additional_authorized_imports=["numpy"], | ||||||
|         ) |         ) | ||||||
|         res = agent.run("ok") |         res = agent.run("ok") | ||||||
|  |  | ||||||
|  | @ -41,7 +41,7 @@ final_answer('This is the final answer.') | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=FakeLLMModel(), |             model=FakeLLMModel(), | ||||||
|             max_iterations=1, |             max_steps=1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         agent.run("Fake task") |         agent.run("Fake task") | ||||||
|  | @ -61,7 +61,7 @@ final_answer('This is the final answer.') | ||||||
|         agent = ToolCallingAgent( |         agent = ToolCallingAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=FakeLLMModel(), |             model=FakeLLMModel(), | ||||||
|             max_iterations=1, |             max_steps=1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         agent.run("Fake task") |         agent.run("Fake task") | ||||||
|  | @ -69,7 +69,7 @@ final_answer('This is the final answer.') | ||||||
|         self.assertEqual(agent.monitor.total_input_token_count, 10) |         self.assertEqual(agent.monitor.total_input_token_count, 10) | ||||||
|         self.assertEqual(agent.monitor.total_output_token_count, 20) |         self.assertEqual(agent.monitor.total_output_token_count, 20) | ||||||
| 
 | 
 | ||||||
|     def test_code_agent_metrics_max_iterations(self): |     def test_code_agent_metrics_max_steps(self): | ||||||
|         class FakeLLMModel: |         class FakeLLMModel: | ||||||
|             def __init__(self): |             def __init__(self): | ||||||
|                 self.last_input_token_count = 10 |                 self.last_input_token_count = 10 | ||||||
|  | @ -81,7 +81,7 @@ final_answer('This is the final answer.') | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=FakeLLMModel(), |             model=FakeLLMModel(), | ||||||
|             max_iterations=1, |             max_steps=1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         agent.run("Fake task") |         agent.run("Fake task") | ||||||
|  | @ -103,7 +103,7 @@ final_answer('This is the final answer.') | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=FakeLLMModel(), |             model=FakeLLMModel(), | ||||||
|             max_iterations=1, |             max_steps=1, | ||||||
|         ) |         ) | ||||||
|         agent.run("Fake task") |         agent.run("Fake task") | ||||||
| 
 | 
 | ||||||
|  | @ -123,7 +123,7 @@ final_answer('This is the final answer.') | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=dummy_model, |             model=dummy_model, | ||||||
|             max_iterations=1, |             max_steps=1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Use stream_to_gradio to capture the output |         # Use stream_to_gradio to capture the output | ||||||
|  | @ -145,7 +145,7 @@ final_answer('This is the final answer.') | ||||||
|         agent = ToolCallingAgent( |         agent = ToolCallingAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=FakeLLM(), |             model=FakeLLM(), | ||||||
|             max_iterations=1, |             max_steps=1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Use stream_to_gradio to capture the output |         # Use stream_to_gradio to capture the output | ||||||
|  | @ -172,7 +172,7 @@ final_answer('This is the final answer.') | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             model=dummy_model, |             model=dummy_model, | ||||||
|             max_iterations=1, |             max_steps=1, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Use stream_to_gradio to capture the output |         # Use stream_to_gradio to capture the output | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue