Vastly simplify Model class ✨ (#146)
* Vastly simplify Model class by making only one __call__ method ✨
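In practice, the whole Model API collapses to one call; a minimal usage sketch (assuming `HF_TOKEN` is set and the `HfApiModel` defaults shown in the diff below):

```py
from smolagents import HfApiModel

model = HfApiModel()  # defaults to "Qwen/Qwen2.5-Coder-32B-Instruct"

# Plain text generation: the reply text lives on .content of the returned message.
message = model([{"role": "user", "content": "Say hello."}])
print(message.content)

# Tool calling goes through the same __call__: pass tools_to_call_from and read
# .tool_calls off the returned message instead (see ToolCallingAgent below).
```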
			
			
This commit is contained in:

parent 36ed279c85
commit 5c33130fa4
@@ -113,8 +113,7 @@ The Python interpreter also doesn't allow imports by default outside of a safe l
 You can authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`CodeAgent`]:
 
 ```py
-from smolagents import CodeAgent
-
+model = HfApiModel()
 agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['requests', 'bs4'])
 agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
 ```
@@ -164,12 +163,12 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
 - **Python code interpreter**: runs the LLM-generated Python code in a secure environment. This tool will only be added to [`ToolCallingAgent`] if you initialize it with `add_base_tools=True`, since code-based agents can already natively execute Python code
 - **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes audio to text.
 
-You can manually use a tool by calling the [`load_tool`] function and a task to perform.
+You can manually use a tool by calling it with its arguments.
 
 ```python
-from smolagents import load_tool
+from smolagents import DuckDuckGoSearchTool
 
-search_tool = load_tool("web_search")
+search_tool = DuckDuckGoSearchTool()
 print(search_tool("Who's the current president of Russia?"))
 ```
 
@@ -776,11 +776,15 @@ class ToolCallingAgent(MultiStepAgent):
         log_entry.agent_memory = agent_memory.copy()
 
         try:
-            tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
+            model_message = self.model(
                 self.input_messages,
-                available_tools=list(self.tools.values()),
+                tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
+            tool_calls = model_message.tool_calls[0]
+            tool_arguments = tool_calls.function.arguments
+            tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
+
         except Exception as e:
             raise AgentGenerationError(
                 f"Error in generating tool call with model:\n{e}"
@@ -913,7 +917,7 @@ class CodeAgent(MultiStepAgent):
                 self.input_messages,
                 stop_sequences=["<end_code>", "Observation:"],
                 **additional_args,
-            )
+            ).content
             log_entry.llm_output = llm_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating model output:\n{e}")
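The `ChatCompletionOutput*` dataclasses that models now return can be built and unpacked standalone; a small sketch of the unpacking done by `ToolCallingAgent` above (field names follow the diff; the `{"answer": "42"}` payload is illustrative):

```py
from huggingface_hub import (
    ChatCompletionOutputMessage,
    ChatCompletionOutputToolCall,
    ChatCompletionOutputFunctionDefinition,
)

# Build the structure a model returns when tools_to_call_from is passed...
message = ChatCompletionOutputMessage(
    role="assistant",
    content="",
    tool_calls=[
        ChatCompletionOutputToolCall(
            id="call_0",
            type="function",
            function=ChatCompletionOutputFunctionDefinition(
                name="final_answer", arguments={"answer": "42"}
            ),
        )
    ],
)

# ...and unpack it the way ToolCallingAgent does above.
tool_call = message.tool_calls[0]
print(tool_call.function.name, tool_call.function.arguments, tool_call.id)
```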
@@ -20,10 +20,16 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional
 
 import torch
-from huggingface_hub import InferenceClient
+from huggingface_hub import (
+    InferenceClient,
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
+
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -33,7 +39,6 @@ from transformers import (
 import openai
 
 from .tools import Tool
-from .utils import parse_json_tool_call
 
 logger = logging.getLogger(__name__)
 
@@ -234,63 +239,46 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
+        temperature: float = 0.5,
     ):
         super().__init__()
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
+        self.temperature = temperature
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        # Send messages to the Hugging Face Inference API
-        if grammar is not None:
-            output = self.client.chat_completion(
-                messages,
-                stop=stop_sequences,
-                response_format=grammar,
-                max_tokens=max_tokens,
-            )
-        else:
-            output = self.client.chat.completions.create(
-                messages, stop=stop_sequences, max_tokens=max_tokens
-            )
-
-        response = output.choices[0].message.content
-        self.last_input_token_count = output.usage.prompt_tokens
-        self.last_output_token_count = output.usage.completion_tokens
-        return response
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences,
-    ):
-        """Generates a tool call for the given message list. This method is used only by `ToolCallingAgent`."""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
+        if tools_to_call_from:
             response = self.client.chat.completions.create(
                 messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+            )
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
             )
-        tool_call = response.choices[0].message.tool_calls[0]
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        return tool_call.function.name, tool_call.function.arguments, tool_call.id
+        return response.choices[0].message
 
 
 class TransformersModel(Model):
@@ -354,18 +342,27 @@ class TransformersModel(Model):
 
         return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
 
-        # Get LLM output
+        if tools_to_call_from is not None:
+            prompt_tensor = self.tokenizer.apply_chat_template(
+                messages,
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                return_tensors="pt",
+                return_dict=True,
+                add_generation_prompt=True,
+            )
+        else:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
                 return_tensors="pt",
@@ -382,56 +379,31 @@ class TransformersModel(Model):
             ),
         )
         generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
         self.last_input_token_count = count_prompt_tokens
         self.last_output_token_count = len(generated_tokens)
 
         if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-        return response
+            output = remove_stop_sequences(output, stop_sequences)
 
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, None], str]:
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        prompt = self.tokenizer.apply_chat_template(
-            messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
-            return_tensors="pt",
-            return_dict=True,
-            add_generation_prompt=True,
-        )
-        prompt = prompt.to(self.model.device)
-        count_prompt_tokens = prompt["input_ids"].shape[1]
-
-        out = self.model.generate(
-            **prompt,
-            max_new_tokens=max_tokens,
-            stopping_criteria=(
-                self.make_stopping_criteria(stop_sequences) if stop_sequences else None
-            ),
-        )
-        generated_tokens = out[0, count_prompt_tokens:]
-        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
-        self.last_input_token_count = count_prompt_tokens
-        self.last_output_token_count = len(generated_tokens)
-
-        if stop_sequences is not None:
-            response = remove_stop_sequences(response, stop_sequences)
-
-        tool_name, tool_input = parse_json_tool_call(response)
-        call_id = "".join(random.choices("0123456789", k=5))
-
-        return tool_name, tool_input, call_id
+        if tools_to_call_from is None:
+            return ChatCompletionOutputMessage(role="assistant", content=output)
+        else:
+            tool_name, tool_arguments = json.load(output)
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="".join(random.choices("0123456789", k=5)),
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name=tool_name, arguments=tool_arguments
+                        ),
+                    )
+                ],
+            )
 
 
 class LiteLLMModel(Model):
@@ -460,38 +432,16 @@ class LiteLLMModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-
+        if tools_to_call_from:
             response = litellm.completion(
                 model=self.model_id,
                 messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            api_base=self.api_base,
-            api_key=self.api_key,
-            **self.kwargs,
-        )
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 1500,
-    ):
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-        response = litellm.completion(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
                 max_tokens=max_tokens,
@@ -499,11 +449,19 @@ class LiteLLMModel(Model):
                 api_key=self.api_key,
                 **self.kwargs,
             )
-        tool_calls = response.choices[0].message.tool_calls[0]
+        else:
+            response = litellm.completion(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        arguments = json.loads(tool_calls.function.arguments)
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message
 
 
 class OpenAIServerModel(Model):
@@ -539,64 +497,40 @@ class OpenAIServerModel(Model):
         self.temperature = temperature
         self.kwargs = kwargs
 
-    def generate(
+    def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
+        tools_to_call_from: Optional[List[Tool]] = None,
     ) -> str:
-        """Generates a text completion for the given message list"""
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-
+        if tools_to_call_from:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
-            stop=stop_sequences,
-            max_tokens=max_tokens,
-            temperature=self.temperature,
-            **self.kwargs,
-        )
-
-        self.last_input_token_count = response.usage.prompt_tokens
-        self.last_output_token_count = response.usage.completion_tokens
-        return response.choices[0].message.content
-
-    def get_tool_call(
-        self,
-        messages: List[Dict[str, str]],
-        available_tools: List[Tool],
-        stop_sequences: Optional[List[str]] = None,
-        max_tokens: int = 500,
-    ) -> Tuple[str, Union[str, Dict], str]:
-        """Generates a tool call for the given message list"""
-        messages = get_clean_message_list(
-            messages, role_conversions=tool_role_conversions
-        )
-
-        response = self.client.chat.completions.create(
-            model=self.model_id,
-            messages=messages,
-            tools=[get_json_schema(tool) for tool in available_tools],
+                tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
                 max_tokens=max_tokens,
                 temperature=self.temperature,
                 **self.kwargs,
             )
-
-        tool_calls = response.choices[0].message.tool_calls[0]
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model_id,
+                messages=messages,
+                stop=stop_sequences,
+                max_tokens=max_tokens,
+                temperature=self.temperature,
+                **self.kwargs,
+            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-
-        try:
-            arguments = json.loads(tool_calls.function.arguments)
-        except json.JSONDecodeError:
-            arguments = tool_calls.function.arguments
-
-        return tool_calls.function.name, arguments, tool_calls.id
+        return response.choices[0].message
 
 
 __all__ = [
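Because all models now share this single `__call__` contract, any callable object can stand in for one; a minimal sketch (the `EchoModel` name and behavior are hypothetical, modeled on the fakes in the tests below):

```py
from huggingface_hub import ChatCompletionOutputMessage


class EchoModel:
    """Hypothetical model satisfying the unified __call__ contract."""

    def __init__(self):
        # Agents read these counters after each call for monitoring.
        self.last_input_token_count = 0
        self.last_output_token_count = 0

    def __call__(
        self,
        messages,
        stop_sequences=None,
        grammar=None,
        max_tokens=1500,
        tools_to_call_from=None,
    ):
        # Echo the last message back as a plain assistant reply.
        return ChatCompletionOutputMessage(
            role="assistant", content=str(messages[-1]["content"])
        )
```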
@@ -30,6 +30,11 @@ from smolagents.agents import (
 from smolagents.default_tools import PythonInterpreterTool
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
 
 
 def get_new_path(suffix="") -> str:
@@ -38,54 +43,106 @@ def get_new_path(suffix="") -> str:
 
 
 class FakeToolCallModel:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return "python_interpreter", {"code": "2*3.6452"}, "call_0"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="python_interpreter", arguments={"code": "2*3.6452"}
+                        ),
+                    )
+                ],
+            )
         else:
-            return "final_answer", {"answer": "7.2904"}, "call_1"
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "7.2904"}
+                        ),
+                    )
+                ],
+            )
 
 
 class FakeToolCallModelImage:
-    def get_tool_call(
-        self, messages, available_tools, stop_sequences=None, grammar=None
+    def __call__(
+        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
     ):
         if len(messages) < 3:
-            return (
-                "fake_image_generation_tool",
-                {"prompt": "An image of a cat"},
-                "call_0",
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_0",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="fake_image_generation_tool",
+                            arguments={"prompt": "An image of a cat"},
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="call_1",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments="image.png"
+                        ),
+                    )
+                ],
             )
-
-        else:  # We're at step 2
-            return "final_answer", "image.png", "call_1"
 
 
 def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = 2**3.6452
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer(7.2904)
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -94,21 +151,27 @@ b = a * 2
 print = 2
 print("Ok, calculation done!")
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
@@ -117,32 +180,41 @@ b = a * 2
     print("Failing due to unexpected indent")
 print("Ok, calculation done!")
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_import(messages, stop_sequences=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I can answer the question
 Code:
 ```py
 import numpy as np
 final_answer("got an error")
 ```<end_code>
-"""
+""",
+    )
 
 
 def fake_code_functiondef(messages, stop_sequences=None) -> str:
     prompt = str(messages)
     if "special_marker" not in prompt:
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: Let's define the function. special_marker
 Code:
 ```py
@@ -151,9 +223,12 @@ import numpy as np
 def moving_average(x, w):
     return np.convolve(x, np.ones(w), 'valid') / w
 ```<end_code>
-"""
+""",
+        )
     else:  # We're at step 2
-        return """
+        return ChatCompletionOutputMessage(
+            role="assistant",
+            content="""
 Thought: I can now answer the initial question
 Code:
 ```py
@@ -161,29 +236,36 @@ x, w = [0, 1, 2, 3, 4, 5], 2
 res = moving_average(x, w)
 final_answer(res)
 ```<end_code>
-"""
+""",
+        )
 
 
 def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 final_answer(result)
 ```
-"""
+""",
+    )
 
 
 def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
-    return """
+    return ChatCompletionOutputMessage(
+        role="assistant",
+        content="""
 Thought: I should multiply 2 by 3.6452. special_marker
 Code:
 ```py
 result = python_interpreter(code="2*3.6452")
 print(result)
 ```
-"""
+""",
+    )
 
 
 class AgentTests(unittest.TestCase):
@@ -360,52 +442,92 @@ class AgentTests(unittest.TestCase):
 
     def test_multiagents(self):
         class FakeModelMultiagentsManagerAgent:
-            def __call__(self, messages, stop_sequences=None, grammar=None):
+            def __call__(
+                self,
+                messages,
+                stop_sequences=None,
+                grammar=None,
+                tools_to_call_from=None,
+            ):
+                if tools_to_call_from is not None:
                     if len(messages) < 3:
-                    return """
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="search_agent",
+                                        arguments="Who is the current US president?",
+                                    ),
+                                )
+                            ],
+                        )
+                    else:
+                        assert "Report on the current US president" in str(messages)
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="",
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    id="call_0",
+                                    type="function",
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        name="final_answer", arguments="Final report."
+                                    ),
+                                )
+                            ],
+                        )
+                else:
+                    if len(messages) < 3:
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
 Thought: Let's call our search agent.
 Code:
 ```py
 result = search_agent("Who is the current US president?")
 ```<end_code>
-"""
+""",
+                        )
                     else:
                         assert "Report on the current US president" in str(messages)
-                    return """
+                        return ChatCompletionOutputMessage(
+                            role="assistant",
+                            content="""
 Thought: Let's return the report.
 Code:
 ```py
 final_answer("Final report.")
 ```<end_code>
-"""
-
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
-            ):
-                if len(messages) < 3:
-                    return (
-                        "search_agent",
-                        "Who is the current US president?",
-                        "call_0",
-                    )
-                else:
-                    assert "Report on the current US president" in str(messages)
-                    return (
-                        "final_answer",
-                        "Final report.",
-                        "call_0",
-                    )
+""",
+                        )
 
         manager_model = FakeModelMultiagentsManagerAgent()
 
         class FakeModelMultiagentsManagedAgent:
-            def get_tool_call(
-                self, messages, available_tools, stop_sequences=None, grammar=None
+            def __call__(
+                self,
+                messages,
+                tools_to_call_from=None,
+                stop_sequences=None,
+                grammar=None,
             ):
-                return (
-                    "final_answer",
-                    {"report": "Report on the current US president"},
-                    "call_0",
+                return ChatCompletionOutputMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatCompletionOutputToolCall(
+                            id="call_0",
+                            type="function",
+                            function=ChatCompletionOutputFunctionDefinition(
+                                name="final_answer",
+                                arguments="Report on the current US president",
+                            ),
+                        )
+                    ],
                 )
 
         managed_model = FakeModelMultiagentsManagedAgent()
@@ -443,13 +565,16 @@ final_answer("Final report.")
 
     def test_code_nontrivial_final_answer_works(self):
         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
-            return """Code:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""Code:
 ```py
 def nested_answer():
     final_answer("Correct!")
 
 nested_answer()
-```<end_code>"""
+```<end_code>""",
+            )
 
         agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
 
@@ -92,7 +92,6 @@ class TestDocs:
             raise ValueError(f"Docs directory not found at {cls.docs_dir}")
 
         load_dotenv()
-        cls.hf_token = os.getenv("HF_TOKEN")
 
         cls.md_files = list(cls.docs_dir.rglob("*.md"))
         if not cls.md_files:
@@ -115,6 +114,7 @@ class TestDocs:
             "from_langchain",  # Langchain is not a dependency
             "while llm_should_continue(memory):",  # This is pseudo code
             "ollama_chat/llama3.2",  # Exclude ollama building in guided tour
+            "model = TransformersModel(model_id=model_id)",  # Exclude testing with transformers model
         ]
         code_blocks = [
             block
@@ -131,10 +131,15 @@ class TestDocs:
             ast.parse(block)
 
         # Create and execute test script
+        print("\n\nCollected code block:==========\n".join(code_blocks))
         try:
             code_blocks = [
-                block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace(
-                    "{your_username}", "m-ric"
+                (
+                    block.replace(
+                        "<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
+                    )
+                    .replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
+                    .replace("{your_username}", "m-ric")
                 )
                 for block in code_blocks
             ]
@@ -22,42 +22,57 @@ from smolagents import (
     ToolCallingAgent,
     stream_to_gradio,
 )
+from huggingface_hub import (
+    ChatCompletionOutputMessage,
+    ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
+)
 
 
-class MonitoringTester(unittest.TestCase):
-    def test_code_agent_metrics(self):
 class FakeLLMModel:
     def __init__(self):
         self.last_input_token_count = 10
         self.last_output_token_count = 20
 
-            def __call__(self, prompt, **kwargs):
-                return """
+    def __call__(self, prompt, tools_to_call_from=None, **kwargs):
+        if tools_to_call_from is not None:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatCompletionOutputToolCall(
+                        id="fake_id",
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            name="final_answer", arguments={"answer": "image"}
+                        ),
+                    )
+                ],
+            )
+        else:
+            return ChatCompletionOutputMessage(
+                role="assistant",
+                content="""
 Code:
 ```py
 final_answer('This is the final answer.')
-```"""
+```""",
+            )
+
 
+class MonitoringTester(unittest.TestCase):
+    def test_code_agent_metrics(self):
         agent = CodeAgent(
             tools=[],
             model=FakeLLMModel(),
             max_steps=1,
         )
-
         agent.run("Fake task")
 
         self.assertEqual(agent.monitor.total_input_token_count, 10)
         self.assertEqual(agent.monitor.total_output_token_count, 20)
 
     def test_json_agent_metrics(self):
-        class FakeLLMModel:
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
-            def get_tool_call(self, prompt, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLMModel(),
@@ -70,17 +85,19 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 20)
 
     def test_code_agent_metrics_max_steps(self):
-        class FakeLLMModel:
+        class FakeLLMModelMalformedAnswer:
             def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20
 
             def __call__(self, prompt, **kwargs):
-                return "Malformed answer"
+                return ChatCompletionOutputMessage(
+                    role="assistant", content="Malformed answer"
+                )
 
         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelMalformedAnswer(),
             max_steps=1,
         )
 
@@ -90,7 +107,7 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 40)
 
     def test_code_agent_metrics_generation_error(self):
-        class FakeLLMModel:
+        class FakeLLMModelGenerationException:
             def __init__(self):
                 self.last_input_token_count = 10
                 self.last_output_token_count = 20
@@ -102,7 +119,7 @@ final_answer('This is the final answer.')
 
         agent = CodeAgent(
             tools=[],
-            model=FakeLLMModel(),
+            model=FakeLLMModelGenerationException(),
             max_steps=1,
         )
         agent.run("Fake task")
@@ -113,16 +130,9 @@ final_answer('This is the final answer.')
         self.assertEqual(agent.monitor.total_output_token_count, 0)
 
     def test_streaming_agent_text_output(self):
-        def dummy_model(prompt, **kwargs):
-            return """
-Code:
-```py
-final_answer('This is the final answer.')
-```"""
-
         agent = CodeAgent(
             tools=[],
-            model=dummy_model,
+            model=FakeLLMModel(),
             max_steps=1,
         )
 
@@ -135,16 +145,9 @@ final_answer('This is the final answer.')
         self.assertIn("This is the final answer.", final_message.content)
 
     def test_streaming_agent_image_output(self):
-        class FakeLLM:
-            def __init__(self):
-                pass
-
-            def get_tool_call(self, messages, **kwargs):
-                return "final_answer", {"answer": "image"}, "fake_id"
-
         agent = ToolCallingAgent(
             tools=[],
-            model=FakeLLM(),
+            model=FakeLLMModel(),
             max_steps=1,
         )
 