This commit is contained in:
Aymeric 2024-12-27 16:18:19 +01:00
parent 710fb75559
commit c880f2f5b6
13 changed files with 115 additions and 48 deletions

View File

@ -122,9 +122,9 @@ def sql_engine(query: str) -> str:
Now let us create an agent that leverages this tool. Now let us create an agent that leverages this tool.
We use the CodeAgent, which is transformers.agents main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework. We use the `CodeAgent`, which is transformers.agents main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
The model is the LLM that powers the agent system. HfModel allows you to call LLMs using HFs Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API. The model is the LLM that powers the agent system. HfApiModel allows you to call LLMs using HFs Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
```py ```py
from smolagents import CodeAgent, HfApiModel from smolagents import CodeAgent, HfApiModel
@ -180,14 +180,14 @@ for table in ["receipts", "waiters"]:
print(updated_description) print(updated_description)
``` ```
Since this request is a bit harder than the previous one, well switch the LLM engine to use the more powerful [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)! Since this request is a bit harder than the previous one, well switch the LLM engine to use the more powerful [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct)!
```py ```py
sql_engine.description = updated_description sql_engine.description = updated_description
agent = CodeAgent( agent = CodeAgent(
tools=[sql_engine], tools=[sql_engine],
model=HfApiModel("Qwen/Qwen2.5-72B-Instruct"), model=HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct"),
) )
agent.run("Which waiter got more total money from tips?") agent.run("Which waiter got more total money from tips?")

View File

@ -19,7 +19,6 @@ dependencies = [
"pandas>=2.2.3", "pandas>=2.2.3",
"jinja2>=3.1.4", "jinja2>=3.1.4",
"pillow>=11.0.0", "pillow>=11.0.0",
"llama-cpp-python>=0.3.4",
"markdownify>=0.14.1", "markdownify>=0.14.1",
"gradio>=5.8.0", "gradio>=5.8.0",
"duckduckgo-search>=6.3.7", "duckduckgo-search>=6.3.7",
@ -30,9 +29,6 @@ dependencies = [
] ]
[project.optional-dependencies] [project.optional-dependencies]
dev = [
"anthropic",
]
test = [ test = [
"gradio-tools" "gradio-tools"
] ]

View File

@ -389,12 +389,16 @@ class MultiStepAgent:
try: try:
if isinstance(arguments, str): if isinstance(arguments, str):
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True) observation = available_tools[tool_name].__call__(
arguments, sanitize_inputs_outputs=True
)
elif isinstance(arguments, dict): elif isinstance(arguments, dict):
for key, value in arguments.items(): for key, value in arguments.items():
if isinstance(value, str) and value in self.state: if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value] arguments[key] = self.state[value]
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) observation = available_tools[tool_name].__call__(
**arguments, sanitize_inputs_outputs=True
)
else: else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
@ -774,10 +778,14 @@ class ToolCallingAgent(MultiStepAgent):
isinstance(answer, str) and answer in self.state.keys() isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value ): # if the answer is a state variable, return the value
final_answer = self.state[answer] final_answer = self.state[answer]
console.print(f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.") console.print(
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'."
)
else: else:
final_answer = answer final_answer = answer
console.print(Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}")) console.print(
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}")
)
log_entry.action_output = final_answer log_entry.action_output = final_answer
return final_answer return final_answer
@ -891,7 +899,12 @@ class CodeAgent(MultiStepAgent):
align="left", align="left",
style="orange", style="orange",
), ),
Syntax(llm_output, lexer="markdown", theme="github-dark", word_wrap=True), Syntax(
llm_output,
lexer="markdown",
theme="github-dark",
word_wrap=True,
),
) )
) )

View File

@ -163,10 +163,12 @@ class DuckDuckGoSearchTool(Tool):
) )
self.ddgs = DDGS() self.ddgs = DDGS()
def forward(self, query: str) -> str: def forward(self, query: str) -> str:
results = self.ddgs.text(query, max_results=10) results = self.ddgs.text(query, max_results=10)
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results] postprocessed_results = [
f"[{result['title']}]({result['href']})\n{result['body']}"
for result in results
]
return "## Search Results\n\n" + "\n\n".join(postprocessed_results) return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
@ -301,7 +303,12 @@ class SpeechToTextTool(PipelineTool):
pre_processor_class = WhisperProcessor pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration model_class = WhisperForConditionalGeneration
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}} inputs = {
"audio": {
"type": "audio",
"description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
}
}
output_type = "string" output_type = "string"
def encode(self, audio): def encode(self, audio):

View File

@ -110,7 +110,6 @@ locals().update({key: value for key, value in pickle_dict.items()})
execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
console.print(execution_logs) console.print(execution_logs)
execution = self.run_code_raise_errors(code_action) execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results: if not execution.results:

View File

@ -129,7 +129,8 @@ def get_clean_message_list(
final_message_list.append(message) final_message_list.append(message)
return final_message_list return final_message_list
class Model():
class Model:
def __init__(self): def __init__(self):
self.last_input_token_count = None self.last_input_token_count = None
self.last_output_token_count = None self.last_output_token_count = None
@ -313,9 +314,16 @@ class TransformersModel(Model):
self.stream = "" self.stream = ""
def __call__(self, input_ids, scores, **kwargs): def __call__(self, input_ids, scores, **kwargs):
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True) generated = self.tokenizer.decode(
input_ids[0][-1], skip_special_tokens=True
)
self.stream += generated self.stream += generated
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]): if any(
[
self.stream.endswith(stop_string)
for stop_string in self.stop_strings
]
):
return True return True
return False return False
@ -458,7 +466,7 @@ __all__ = [
"MessageRole", "MessageRole",
"tool_role_conversions", "tool_role_conversions",
"get_clean_message_list", "get_clean_message_list",
"HfModel", "Model",
"TransformersModel", "TransformersModel",
"HfApiModel", "HfApiModel",
"LiteLLMModel", "LiteLLMModel",

View File

@ -32,7 +32,7 @@ class Monitor:
def get_total_token_counts(self): def get_total_token_counts(self):
return { return {
"input": self.total_input_token_count, "input": self.total_input_token_count,
"output": self.total_output_token_count "output": self.total_output_token_count,
} }
def reset(self): def reset(self):

View File

@ -1,6 +1,5 @@
import ast import ast
import inspect import inspect
import importlib.util
import builtins import builtins
from typing import Set from typing import Set
import textwrap import textwrap

View File

@ -108,6 +108,7 @@ def validate_after_init(cls):
cls.__init__ = new_init cls.__init__ = new_init
return cls return cls
def _convert_type_hints_to_json_schema(func: Callable) -> Dict: def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
type_hints = get_type_hints(func) type_hints = get_type_hints(func)
signature = inspect.signature(func) signature = inspect.signature(func)
@ -119,10 +120,13 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
properties[param_name]["nullable"] = True properties[param_name]["nullable"] = True
for param_name in signature.parameters.keys(): for param_name in signature.parameters.keys():
if signature.parameters[param_name].default != inspect.Parameter.empty: if signature.parameters[param_name].default != inspect.Parameter.empty:
if param_name not in properties: # this can happen if the param has no type hint but a default value if (
param_name not in properties
): # this can happen if the param has no type hint but a default value
properties[param_name] = {"nullable": True} properties[param_name] = {"nullable": True}
return properties return properties
AUTHORIZED_TYPES = [ AUTHORIZED_TYPES = [
"string", "string",
"boolean", "boolean",
@ -202,7 +206,10 @@ class Tool:
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
# Validate forward function signature, except for PipelineTool # Validate forward function signature, except for PipelineTool
if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True): if not (
hasattr(self, "is_pipeline_tool")
and getattr(self, "is_pipeline_tool") is True
):
signature = inspect.signature(self.forward) signature = inspect.signature(self.forward)
if not set(signature.parameters.keys()) == set(self.inputs.keys()): if not set(signature.parameters.keys()) == set(self.inputs.keys()):
@ -213,9 +220,13 @@ class Tool:
json_schema = _convert_type_hints_to_json_schema(self.forward) json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items(): for key, value in self.inputs.items():
if "nullable" in value: if "nullable" in value:
assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." assert (
key in json_schema and "nullable" in json_schema[key]
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
if key in json_schema and "nullable" in json_schema[key]: if key in json_schema and "nullable" in json_schema[key]:
assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." assert (
"nullable" in value
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.") return NotImplementedError("Write this method in your subclass of `Tool`.")

View File

@ -249,7 +249,11 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio} AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, torch.Tensor: AgentAudio} INSTANCE_TYPE_MAPPING = {
str: AgentText,
ImageType: AgentImage,
torch.Tensor: AgentAudio,
}
if is_torch_available(): if is_torch_available():
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio

View File

@ -106,6 +106,7 @@ final_answer("got an error")
```<end_code> ```<end_code>
""" """
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str: def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
prompt = str(messages) prompt = str(messages)
if "special_marker" not in prompt: if "special_marker" not in prompt:
@ -255,12 +256,13 @@ class AgentTests(unittest.TestCase):
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs) assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self): def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error) agent = CodeAgent(
tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText) assert isinstance(output, AgentText)
assert output == "got an error" assert output == "got an error"
assert " print(\"Failing due to unexpected indent\")" in str(agent.logs) assert ' print("Failing due to unexpected indent")' in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self): def test_setup_agent_with_empty_toolbox(self):
ToolCallingAgent(model=FakeToolCallModel(), tools=[]) ToolCallingAgent(model=FakeToolCallModel(), tools=[])

View File

@ -16,6 +16,7 @@ import unittest
from smolagents import models, tool from smolagents import models, tool
from typing import Optional from typing import Optional
class ModelTests(unittest.TestCase): class ModelTests(unittest.TestCase):
def test_get_json_schema_has_nullable_args(self): def test_get_json_schema_has_nullable_args(self):
@tool @tool
@ -29,4 +30,10 @@ class ModelTests(unittest.TestCase):
celsius: the temperature type celsius: the temperature type
""" """
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
assert (
"nullable"
in models.get_json_schema(get_weather)["function"]["parameters"][
"properties"
]["celsius"]
)

View File

@ -286,18 +286,24 @@ class ToolTests(unittest.TestCase):
def test_tool_missing_class_attributes_raises_error(self): def test_tool_missing_class_attributes_raises_error(self):
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
class GetWeatherTool(Tool): class GetWeatherTool(Tool):
name = "get_weather" name = "get_weather"
description = "Get weather in the next days at given location." description = "Get weather in the next days at given location."
inputs = { inputs = {
"location": {"type": "string", "description": "the location"}, "location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"} "celsius": {
"type": "string",
"description": "the temperature type",
},
} }
def forward(self, location: str, celsius: Optional[bool] = False) -> str: def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool() GetWeatherTool()
assert "You must set an attribute output_type" in str(e) assert "You must set an attribute output_type" in str(e)
def test_tool_from_decorator_optional_args(self): def test_tool_from_decorator_optional_args(self):
@ -314,54 +320,69 @@ class ToolTests(unittest.TestCase):
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert "nullable" in get_weather.inputs["celsius"] assert "nullable" in get_weather.inputs["celsius"]
assert get_weather.inputs["celsius"]["nullable"] == True assert get_weather.inputs["celsius"]["nullable"]
assert "nullable" not in get_weather.inputs["location"] assert "nullable" not in get_weather.inputs["location"]
def test_tool_mismatching_nullable_args_raises_error(self): def test_tool_mismatching_nullable_args_raises_error(self):
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
class GetWeatherTool(Tool): class GetWeatherTool(Tool):
name = "get_weather" name = "get_weather"
description = "Get weather in the next days at given location." description = "Get weather in the next days at given location."
inputs = { inputs = {
"location": {"type": "string", "description": "the location"}, "location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"} "celsius": {
"type": "string",
"description": "the temperature type",
},
} }
output_type = "string" output_type = "string"
def forward(self, location: str, celsius: Optional[bool] = False) -> str: def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool() GetWeatherTool()
assert "Nullable" in str(e) assert "Nullable" in str(e)
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
class GetWeatherTool2(Tool): class GetWeatherTool2(Tool):
name = "get_weather" name = "get_weather"
description = "Get weather in the next days at given location." description = "Get weather in the next days at given location."
inputs = { inputs = {
"location": {"type": "string", "description": "the location"}, "location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"} "celsius": {
"type": "string",
"description": "the temperature type",
},
} }
output_type = "string" output_type = "string"
def forward(self, location: str, celsius: bool = False) -> str: def forward(self, location: str, celsius: bool = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool2() GetWeatherTool2()
assert "Nullable" in str(e) assert "Nullable" in str(e)
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
class GetWeatherTool3(Tool): class GetWeatherTool3(Tool):
name = "get_weather" name = "get_weather"
description = "Get weather in the next days at given location." description = "Get weather in the next days at given location."
inputs = { inputs = {
"location": {"type": "string", "description": "the location"}, "location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type", "nullable": True} "celsius": {
"type": "string",
"description": "the temperature type",
"nullable": True,
},
} }
output_type = "string" output_type = "string"
def forward(self, location, celsius: str) -> str: def forward(self, location, celsius: str) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool3() GetWeatherTool3()
assert "Nullable" in str(e) assert "Nullable" in str(e)