Format
This commit is contained in:
parent
710fb75559
commit
c880f2f5b6
|
@ -122,9 +122,9 @@ def sql_engine(query: str) -> str:
|
|||
|
||||
Now let us create an agent that leverages this tool.
|
||||
|
||||
We use the CodeAgent, which is transformers.agents’ main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
|
||||
We use the `CodeAgent`, which is transformers.agents’ main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
|
||||
|
||||
The model is the LLM that powers the agent system. HfModel allows you to call LLMs using HF’s Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
|
||||
The model is the LLM that powers the agent system. HfApiModel allows you to call LLMs using HF’s Inference API, either via Serverless or Dedicated endpoint, but you could also use any proprietary API.
|
||||
|
||||
```py
|
||||
from smolagents import CodeAgent, HfApiModel
|
||||
|
@ -180,14 +180,14 @@ for table in ["receipts", "waiters"]:
|
|||
|
||||
print(updated_description)
|
||||
```
|
||||
Since this request is a bit harder than the previous one, we’ll switch the LLM engine to use the more powerful [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)!
|
||||
Since this request is a bit harder than the previous one, we’ll switch the LLM engine to use the more powerful [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct)!
|
||||
|
||||
```py
|
||||
sql_engine.description = updated_description
|
||||
|
||||
agent = CodeAgent(
|
||||
tools=[sql_engine],
|
||||
model=HfApiModel("Qwen/Qwen2.5-72B-Instruct"),
|
||||
model=HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct"),
|
||||
)
|
||||
|
||||
agent.run("Which waiter got more total money from tips?")
|
||||
|
|
|
@ -19,7 +19,6 @@ dependencies = [
|
|||
"pandas>=2.2.3",
|
||||
"jinja2>=3.1.4",
|
||||
"pillow>=11.0.0",
|
||||
"llama-cpp-python>=0.3.4",
|
||||
"markdownify>=0.14.1",
|
||||
"gradio>=5.8.0",
|
||||
"duckduckgo-search>=6.3.7",
|
||||
|
@ -30,9 +29,6 @@ dependencies = [
|
|||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"anthropic",
|
||||
]
|
||||
test = [
|
||||
"gradio-tools"
|
||||
]
|
||||
|
|
|
@ -389,12 +389,16 @@ class MultiStepAgent:
|
|||
|
||||
try:
|
||||
if isinstance(arguments, str):
|
||||
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
|
||||
observation = available_tools[tool_name].__call__(
|
||||
arguments, sanitize_inputs_outputs=True
|
||||
)
|
||||
elif isinstance(arguments, dict):
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, str) and value in self.state:
|
||||
arguments[key] = self.state[value]
|
||||
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
||||
observation = available_tools[tool_name].__call__(
|
||||
**arguments, sanitize_inputs_outputs=True
|
||||
)
|
||||
else:
|
||||
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||
raise AgentExecutionError(error_msg)
|
||||
|
@ -774,10 +778,14 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
isinstance(answer, str) and answer in self.state.keys()
|
||||
): # if the answer is a state variable, return the value
|
||||
final_answer = self.state[answer]
|
||||
console.print(f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.")
|
||||
console.print(
|
||||
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'."
|
||||
)
|
||||
else:
|
||||
final_answer = answer
|
||||
console.print(Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"))
|
||||
console.print(
|
||||
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}")
|
||||
)
|
||||
|
||||
log_entry.action_output = final_answer
|
||||
return final_answer
|
||||
|
@ -891,7 +899,12 @@ class CodeAgent(MultiStepAgent):
|
|||
align="left",
|
||||
style="orange",
|
||||
),
|
||||
Syntax(llm_output, lexer="markdown", theme="github-dark", word_wrap=True),
|
||||
Syntax(
|
||||
llm_output,
|
||||
lexer="markdown",
|
||||
theme="github-dark",
|
||||
word_wrap=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -163,10 +163,12 @@ class DuckDuckGoSearchTool(Tool):
|
|||
)
|
||||
self.ddgs = DDGS()
|
||||
|
||||
|
||||
def forward(self, query: str) -> str:
|
||||
results = self.ddgs.text(query, max_results=10)
|
||||
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
|
||||
postprocessed_results = [
|
||||
f"[{result['title']}]({result['href']})\n{result['body']}"
|
||||
for result in results
|
||||
]
|
||||
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
||||
|
||||
|
||||
|
@ -301,7 +303,12 @@ class SpeechToTextTool(PipelineTool):
|
|||
pre_processor_class = WhisperProcessor
|
||||
model_class = WhisperForConditionalGeneration
|
||||
|
||||
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}}
|
||||
inputs = {
|
||||
"audio": {
|
||||
"type": "audio",
|
||||
"description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
|
||||
}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def encode(self, audio):
|
||||
|
|
|
@ -110,7 +110,6 @@ locals().update({key: value for key, value in pickle_dict.items()})
|
|||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||
console.print(execution_logs)
|
||||
|
||||
|
||||
execution = self.run_code_raise_errors(code_action)
|
||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||
if not execution.results:
|
||||
|
|
|
@ -129,7 +129,8 @@ def get_clean_message_list(
|
|||
final_message_list.append(message)
|
||||
return final_message_list
|
||||
|
||||
class Model():
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = None
|
||||
self.last_output_token_count = None
|
||||
|
@ -313,9 +314,16 @@ class TransformersModel(Model):
|
|||
self.stream = ""
|
||||
|
||||
def __call__(self, input_ids, scores, **kwargs):
|
||||
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
|
||||
generated = self.tokenizer.decode(
|
||||
input_ids[0][-1], skip_special_tokens=True
|
||||
)
|
||||
self.stream += generated
|
||||
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
|
||||
if any(
|
||||
[
|
||||
self.stream.endswith(stop_string)
|
||||
for stop_string in self.stop_strings
|
||||
]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -458,7 +466,7 @@ __all__ = [
|
|||
"MessageRole",
|
||||
"tool_role_conversions",
|
||||
"get_clean_message_list",
|
||||
"HfModel",
|
||||
"Model",
|
||||
"TransformersModel",
|
||||
"HfApiModel",
|
||||
"LiteLLMModel",
|
||||
|
|
|
@ -32,7 +32,7 @@ class Monitor:
|
|||
def get_total_token_counts(self):
|
||||
return {
|
||||
"input": self.total_input_token_count,
|
||||
"output": self.total_output_token_count
|
||||
"output": self.total_output_token_count,
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import ast
|
||||
import inspect
|
||||
import importlib.util
|
||||
import builtins
|
||||
from typing import Set
|
||||
import textwrap
|
||||
|
|
|
@ -108,6 +108,7 @@ def validate_after_init(cls):
|
|||
cls.__init__ = new_init
|
||||
return cls
|
||||
|
||||
|
||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||
type_hints = get_type_hints(func)
|
||||
signature = inspect.signature(func)
|
||||
|
@ -119,10 +120,13 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
|||
properties[param_name]["nullable"] = True
|
||||
for param_name in signature.parameters.keys():
|
||||
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
||||
if param_name not in properties: # this can happen if the param has no type hint but a default value
|
||||
if (
|
||||
param_name not in properties
|
||||
): # this can happen if the param has no type hint but a default value
|
||||
properties[param_name] = {"nullable": True}
|
||||
return properties
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = [
|
||||
"string",
|
||||
"boolean",
|
||||
|
@ -202,7 +206,10 @@ class Tool:
|
|||
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
|
||||
|
||||
# Validate forward function signature, except for PipelineTool
|
||||
if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True):
|
||||
if not (
|
||||
hasattr(self, "is_pipeline_tool")
|
||||
and getattr(self, "is_pipeline_tool") is True
|
||||
):
|
||||
signature = inspect.signature(self.forward)
|
||||
|
||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||
|
@ -213,9 +220,13 @@ class Tool:
|
|||
json_schema = _convert_type_hints_to_json_schema(self.forward)
|
||||
for key, value in self.inputs.items():
|
||||
if "nullable" in value:
|
||||
assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||
assert (
|
||||
key in json_schema and "nullable" in json_schema[key]
|
||||
), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||
if key in json_schema and "nullable" in json_schema[key]:
|
||||
assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
|
||||
assert (
|
||||
"nullable" in value
|
||||
), f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
||||
|
|
|
@ -249,7 +249,11 @@ class AgentAudio(AgentType, str):
|
|||
|
||||
|
||||
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, torch.Tensor: AgentAudio}
|
||||
INSTANCE_TYPE_MAPPING = {
|
||||
str: AgentText,
|
||||
ImageType: AgentImage,
|
||||
torch.Tensor: AgentAudio,
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
|
||||
|
|
|
@ -106,6 +106,7 @@ final_answer("got an error")
|
|||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
|
@ -255,12 +256,13 @@ class AgentTests(unittest.TestCase):
|
|||
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
|
||||
|
||||
def test_code_agent_syntax_error_show_offending_lines(self):
|
||||
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
|
||||
agent = CodeAgent(
|
||||
tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
|
||||
)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, AgentText)
|
||||
assert output == "got an error"
|
||||
assert " print(\"Failing due to unexpected indent\")" in str(agent.logs)
|
||||
|
||||
assert ' print("Failing due to unexpected indent")' in str(agent.logs)
|
||||
|
||||
def test_setup_agent_with_empty_toolbox(self):
|
||||
ToolCallingAgent(model=FakeToolCallModel(), tools=[])
|
||||
|
|
|
@ -16,6 +16,7 @@ import unittest
|
|||
from smolagents import models, tool
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ModelTests(unittest.TestCase):
|
||||
def test_get_json_schema_has_nullable_args(self):
|
||||
@tool
|
||||
|
@ -29,4 +30,10 @@ class ModelTests(unittest.TestCase):
|
|||
celsius: the temperature type
|
||||
"""
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
|
||||
|
||||
assert (
|
||||
"nullable"
|
||||
in models.get_json_schema(get_weather)["function"]["parameters"][
|
||||
"properties"
|
||||
]["celsius"]
|
||||
)
|
||||
|
|
|
@ -286,18 +286,24 @@ class ToolTests(unittest.TestCase):
|
|||
|
||||
def test_tool_missing_class_attributes_raises_error(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
class GetWeatherTool(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type"}
|
||||
"celsius": {
|
||||
"type": "string",
|
||||
"description": "the temperature type",
|
||||
},
|
||||
}
|
||||
|
||||
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
|
||||
def forward(
|
||||
self, location: str, celsius: Optional[bool] = False
|
||||
) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool()
|
||||
|
||||
GetWeatherTool()
|
||||
assert "You must set an attribute output_type" in str(e)
|
||||
|
||||
def test_tool_from_decorator_optional_args(self):
|
||||
|
@ -312,56 +318,71 @@ class ToolTests(unittest.TestCase):
|
|||
celsius: the temperature type
|
||||
"""
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
|
||||
assert "nullable" in get_weather.inputs["celsius"]
|
||||
assert get_weather.inputs["celsius"]["nullable"] == True
|
||||
assert get_weather.inputs["celsius"]["nullable"]
|
||||
assert "nullable" not in get_weather.inputs["location"]
|
||||
|
||||
def test_tool_mismatching_nullable_args_raises_error(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
class GetWeatherTool(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type"}
|
||||
"celsius": {
|
||||
"type": "string",
|
||||
"description": "the temperature type",
|
||||
},
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
|
||||
def forward(
|
||||
self, location: str, celsius: Optional[bool] = False
|
||||
) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool()
|
||||
|
||||
GetWeatherTool()
|
||||
assert "Nullable" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
class GetWeatherTool2(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type"}
|
||||
"celsius": {
|
||||
"type": "string",
|
||||
"description": "the temperature type",
|
||||
},
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, location: str, celsius: bool = False) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool2()
|
||||
|
||||
GetWeatherTool2()
|
||||
assert "Nullable" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
class GetWeatherTool3(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type", "nullable": True}
|
||||
"celsius": {
|
||||
"type": "string",
|
||||
"description": "the temperature type",
|
||||
"nullable": True,
|
||||
},
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, location, celsius: str) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool3()
|
||||
|
||||
GetWeatherTool3()
|
||||
assert "Nullable" in str(e)
|
||||
|
|
Loading…
Reference in New Issue