Add examples to use any LLM as engine

Aymeric 2024-12-13 19:48:55 +01:00
parent 8ed03634b0
commit 9232528232
20 changed files with 202 additions and 65 deletions

View File

@@ -261,7 +261,7 @@ For maximum flexibility, you can overwrite the whole system prompt template by p
```python
from transformers import JsonAgent
-from transformers.agents import PythonInterpreterTool
+from agents import PythonInterpreterTool
agent = JsonAgent(tools=[PythonInterpreterTool()], system_prompt="{your_custom_prompt}")
```
@@ -381,14 +381,14 @@ Multi-agent has been introduced in Microsoft's framework [Autogen](https://huggi
It simply means having several agents working together to solve your task instead of only one.
It empirically yields better performance on most benchmarks. The reason for this better performance is conceptually simple: for many tasks, rather than using a do-it-all system, you would prefer to specialize units on sub-tasks. Here, having agents with separate tool sets and memories allows to achieve efficient specialization.
-You can easily build hierarchical multi-agent systems with `transformers.agents`.
+You can easily build hierarchical multi-agent systems with `agents`.
To do so, encapsulate the agent in a [`ManagedAgent`] object. This object needs arguments `agent`, `name`, and a `description`, which will then be embedded in the manager agent's system prompt to let it know how to call this managed agent, as we also do for tools.
Here's an example of making an agent that managed a specific web search agent using our [`DuckDuckGoSearchTool`]:
```py
-from transformers.agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent
+from agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent
llm_engine = HfApiEngine()
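# --- Sketch, not part of the commit: the hunk above is cut off right here. ---
# Based on the imports shown above and the ManagedAgent description, the doc
# example presumably continues along these lines (names are illustrative, and
# the `managed_agents` argument is assumed from the surrounding description):
web_agent = CodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=llm_engine)
managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    description="Runs web searches for you. Give it your query as an argument.",
)
manager_agent = CodeAgent(tools=[], llm_engine=llm_engine, managed_agents=[managed_web_agent])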

View File

@@ -19,7 +19,7 @@ There's a world of difference between building an agent that works and one that
In this guide, we're going to see best practices for building agents.
> [!TIP]
-> If you're new to `transformers.agents`, make sure to first read the [intro to agents](./intro_agents).
+> If you're new to `agents`, make sure to first read the [intro to agents](./intro_agents).
### The best agentic systems are the simplest: simplify the workflow as much as you can

View File

@@ -156,12 +156,12 @@ These types have three specific purposes:
### AgentText
-[[autodoc]] transformers.agents.agent_types.AgentText
+[[autodoc]] agents.types.AgentText
### AgentImage
-[[autodoc]] transformers.agents.agent_types.AgentImage
+[[autodoc]] agents.types.AgentImage
### AgentAudio
-[[autodoc]] transformers.agents.agent_types.AgentAudio
+[[autodoc]] agents.types.AgentAudio

View File

@@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.
Here, we're going to see advanced tool usage.
> [!TIP]
-> If you're new to `transformers.agents`, make sure to first read the main [agents documentation](./agents).
+> If you're new to `agents`, make sure to first read the main [agents documentation](./agents).
### Directly define a tool by subclassing Tool, and share it to the Hub

View File

@@ -0,0 +1,30 @@
+from agents import OpenAIEngine, AnthropicEngine, HfApiEngine, CodeAgent
+from dotenv import load_dotenv
+
+load_dotenv()
+
+openai_engine = OpenAIEngine(model_name="gpt-4o")
+agent = CodeAgent([], llm_engine=openai_engine)
+
+print("\n\n##############")
+print("Running OpenAI agent:")
+agent.run("What is the 10th Fibonacci Number?")
+
+anthropic_engine = AnthropicEngine()
+agent = CodeAgent([], llm_engine=anthropic_engine)
+
+print("\n\n##############")
+print("Running Anthropic agent:")
+agent.run("What is the 10th Fibonacci Number?")
+
+# Here, the token stored in the HF_TOKEN environment variable needs the scopes 'Make calls to the serverless Inference API' and 'Read access to contents of all public gated repos you can access'
+llama_engine = HfApiEngine(model="meta-llama/Llama-3.3-70B-Instruct")
+agent = CodeAgent([], llm_engine=llama_engine)
+
+print("\n\n##############")
+print("Running Llama3.3-70B agent:")
+agent.run("What is the 10th Fibonacci Number?")

View File

@@ -22,3 +22,8 @@ dependencies = [
    "duckduckgo-search>=6.3.7",
    "python-dotenv>=1.0.1"
]
+
+[project.optional-dependencies]
+dev = [
+    "anthropic",
+]

View File

@@ -26,9 +26,17 @@ from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING:
    from .agents import *
-    from .llm_engine import *
+    from .default_tools import *
+    from .gradio_ui import *
+    from .llm_engines import *
+    from .local_python_executor import *
    from .monitoring import *
+    from .prompts import *
+    from .search import *
    from .tools import *
+    from .types import *
+    from .utils import *
else:
    import sys
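A hedged reading of this change: re-exporting these submodules at the package root is what lets the rest of the diff (docs, tests, and the new example script) import everything directly from `agents`, for instance:

```py
from agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, AgentImage
```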

View File

@@ -22,9 +22,9 @@ from rich.syntax import Syntax
from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
-from .agent_types import AgentAudio, AgentImage
+from .types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
-from .llm_engine import HfApiEngine, MessageRole
+from .llm_engines import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (
    CODE_SYSTEM_PROMPT,
@@ -492,7 +492,7 @@ class ReactAgent(BaseAgent):
    Example:
    ```py
-    from transformers.agents import CodeAgent
+    from agents import CodeAgent
    agent = CodeAgent(tools=[])
    agent.run("What is the result of 2 power 3.7384?")
    ```
@@ -811,7 +811,7 @@ class JsonAgent(ReactAgent):
            )
            log_entry.llm_output = llm_output
        except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm output: {e}.")
+            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
        if self.verbose:
            console.rule("[italic]Output message of the LLM:")
@@ -944,7 +944,7 @@ class CodeAgent(ReactAgent):
            )
            log_entry.llm_output = llm_output
        except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm output: {e}.")
+            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
        if self.verbose:
            console.rule("[italic]Output message of the LLM:")
@@ -1074,4 +1074,4 @@ And even if your task resolution is not successful, please return as much contex
        else:
            return output
-__all__ = ["BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
+__all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]

View File

@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .agent_types import AgentAudio, AgentImage, AgentText
+from .types import AgentAudio, AgentImage, AgentText
from .agents import BaseAgent, AgentStep, ActionStep
import gradio as gr

View File

@@ -22,10 +22,20 @@ from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, Pipeline
import logging
+import os
+from openai import OpenAI
logger = logging.getLogger(__name__)
+DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
+}
+DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
+}
class MessageRole(str, Enum):
    USER = "user"
@@ -38,6 +48,13 @@ class MessageRole(str, Enum):
    def roles(cls):
        return [r.value for r in cls]
+openai_role_conversions = {
+    MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}
+llama_role_conversions = {
+    MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}
def get_clean_message_list(
    message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
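For context, a sketch of what these role-conversion maps are for (illustrative, not part of the diff): they are passed to `get_clean_message_list` so that roles the target API does not accept, such as tool responses, are re-labeled before the request is sent:

```py
messages = [
    {"role": MessageRole.SYSTEM, "content": "You are an agent."},
    {"role": MessageRole.TOOL_RESPONSE, "content": "Observation: 42"},
]
# With openai_role_conversions, the tool-response message is re-labeled as a user message.
clean = get_clean_message_list(messages, role_conversions=openai_role_conversions)
```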
@@ -73,11 +90,6 @@ def get_clean_message_list(
    return final_message_list
-llama_role_conversions = {
-    MessageRole.TOOL_RESPONSE: MessageRole.USER,
-}
class HfEngine:
    def __init__(self, model_id: Optional[str] = None):
        self.last_input_token_count = None
@@ -106,6 +118,7 @@ class HfEngine:
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
+        max_tokens: int = 1500
    ):
        raise NotImplementedError
@@ -114,6 +127,7 @@ class HfEngine:
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
    ) -> str:
        """Process the input messages and return the model's response.
@@ -133,7 +147,7 @@ class HfEngine:
        Example:
        ```python
        >>> engine = HfApiEngine(
-        ...     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        ...     model="Qwen/Qwen2.5-Coder-32B-Instruct",
        ...     token="your_hf_token_here",
        ...     max_tokens=2000
        ... )
@@ -149,7 +163,7 @@ class HfEngine:
        )
        if stop_sequences is None:
            stop_sequences = []
-        response = self.generate(messages, stop_sequences, grammar)
+        response = self.generate(messages, stop_sequences, grammar, max_tokens)
        self.last_input_token_count = len(
            self.tokenizer.apply_chat_template(messages, tokenize=True)
        )
@@ -168,11 +182,12 @@ class HfApiEngine(HfEngine):
    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
    Parameters:
-        model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
+        model (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
        token (`str`, *optional*):
-            Token used by the Hugging Face API for authentication.
-            If not provided, the class will use the token stored in the Hugging Face CLI configuration.
+            Token used by the Hugging Face API for authentication. This token needs to be authorized for 'Make calls to the serverless Inference API'.
+            If the model is gated (like Llama-3 models), the token also needs 'Read access to contents of all public gated repos you can access'.
+            If not provided, the class will try the 'HF_TOKEN' environment variable, and otherwise use the token stored in the Hugging Face CLI configuration.
        max_tokens (`int`, *optional*, defaults to 1500):
            The maximum number of tokens allowed in the output.
        timeout (`int`, *optional*, defaults to 120):
@@ -185,21 +200,22 @@ class HfApiEngine(HfEngine):
    def __init__(
        self,
-        model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        model: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
        token: Optional[str] = None,
-        max_tokens: Optional[int] = 1500,
        timeout: Optional[int] = 120,
    ):
        super().__init__(model_id=model)
        self.model = model
+        if token is None:
+            token = os.getenv("HF_TOKEN")
        self.client = InferenceClient(self.model, token=token, timeout=timeout)
-        self.max_tokens = max_tokens
    def generate(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(
@@ -211,12 +227,12 @@ class HfApiEngine(HfEngine):
            response = self.client.chat_completion(
                messages,
                stop=stop_sequences,
-                max_tokens=self.max_tokens,
                response_format=grammar,
+                max_tokens=max_tokens,
            )
        else:
            response = self.client.chat_completion(
-                messages, stop=stop_sequences, max_tokens=self.max_tokens
+                messages, stop=stop_sequences, max_tokens=max_tokens
            )
        response = response.choices[0].message.content
@@ -235,7 +251,7 @@ class TransformersEngine(HfEngine):
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
-        max_length: int = 1500,
+        max_tokens: int = 1500,
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(
@@ -251,7 +267,7 @@ class TransformersEngine(HfEngine):
        output = self.pipeline(
            messages,
            stop_strings=stop_strings,
-            max_length=max_length,
+            max_length=max_tokens,
            tokenizer=self.pipeline.tokenizer,
        )
@@ -259,14 +275,95 @@ class TransformersEngine(HfEngine):
        return response
-DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
-    "type": "regex",
-    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
-}
-DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
-    "type": "regex",
-    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
-}
-__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine"]
+class OpenAIEngine:
+    def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None):
+        """Creates an LLM engine that follows the OpenAI format.
+
+        Args:
+            model_name (`str`, *optional*): the model name to use.
+            api_key (`str`, *optional*): your API key.
+            base_url (`str`, *optional*): the URL to use if using a different inference service than OpenAI, for instance "https://api-inference.huggingface.co/v1/".
+        """
+        if model_name is None:
+            model_name = "gpt-4o"
+        if api_key is None:
+            api_key = os.getenv("OPENAI_API_KEY")
+        self.model_name = model_name
+        self.client = OpenAI(
+            base_url=base_url,
+            api_key=api_key,
+        )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            stop=stop_sequences,
+            temperature=0.5,
+            max_tokens=max_tokens,
+        )
+        return response.choices[0].message.content
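As the docstring above notes, `base_url` lets this same class talk to any OpenAI-compatible server. A hedged usage sketch (the model name and token source are illustrative, not part of the diff):

```py
import os

from agents import CodeAgent, OpenAIEngine

# Point OpenAIEngine at an OpenAI-compatible endpoint instead of api.openai.com.
engine = OpenAIEngine(
    model_name="meta-llama/Llama-3.3-70B-Instruct",
    api_key=os.getenv("HF_TOKEN"),
    base_url="https://api-inference.huggingface.co/v1/",
)
agent = CodeAgent([], llm_engine=engine)
agent.run("What is the 10th Fibonacci Number?")
```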
+class AnthropicEngine:
+    def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):
+        from anthropic import Anthropic, AnthropicBedrock
+
+        self.model_name = model_name
+        if use_bedrock:
+            self.model_name = "anthropic.claude-3-5-sonnet-20240620-v1:0"
+            self.client = AnthropicBedrock(
+                aws_access_key=os.getenv("AWS_BEDROCK_ID"),
+                aws_secret_key=os.getenv("AWS_BEDROCK_KEY"),
+                aws_region="us-east-1",
+            )
+        else:
+            self.client = Anthropic(
+                api_key=os.getenv("ANTHROPIC_API_KEY"),
+            )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
+        index_system_message, system_prompt = None, None
+        for index, message in enumerate(messages):
+            if message["role"] == MessageRole.SYSTEM:
+                index_system_message = index
+                system_prompt = message["content"]
+        if system_prompt is None:
+            raise Exception("No system prompt found!")
+        filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message]
+        if len(filtered_messages) == 0:
+            print("Error, no user message:", messages)
+            assert False
+        response = self.client.messages.create(
+            model=self.model_name,
+            system=system_prompt,
+            messages=filtered_messages,
+            stop_sequences=stop_sequences,
+            temperature=0.5,
+            max_tokens=max_tokens,
+        )
+        full_response_text = ""
+        for content_block in response.content:
+            if content_block.type == "text":
+                full_response_text += content_block.text
+        return full_response_text
+
+__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"]

View File

@@ -231,7 +231,6 @@ def evaluate_class_def(class_def, state, static_tools, custom_tools):
def evaluate_augassign(expression, state, static_tools, custom_tools):
-    # Helper function to get current value and set new value based on the target type
    def get_current_value(target):
        if isinstance(target, ast.Name):
            return state.get(target.id, 0)
@@ -254,7 +253,6 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
    current_value = get_current_value(expression.target)
    value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
-    # Determine the operation and apply it
    if isinstance(expression.op, ast.Add):
        if isinstance(current_value, list):
            if not isinstance(value_to_add, list):

View File

@@ -51,7 +51,7 @@ from transformers.utils import (
    is_torch_available,
    is_vision_available,
)
-from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
+from .types import ImageType, handle_agent_inputs, handle_agent_outputs
import logging
logger = logging.getLogger(__name__)
@@ -928,8 +928,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
    """
    if task_or_repo_id in TOOL_MAPPING:
        tool_class_name = TOOL_MAPPING[task_or_repo_id]
-        main_module = importlib.import_module("transformers")
-        tools_module = main_module.agents
+        main_module = importlib.import_module("agents")
+        tools_module = main_module
        tool_class = getattr(tools_module, tool_class_name)
        return tool_class(model_repo_id, token=token, **kwargs)
    else:

View File

@@ -20,7 +20,7 @@ import pytest
from pathlib import Path
-from agents.agent_types import AgentText
+from agents.types import AgentText
from agents.agents import (
    AgentMaxIterationsError,
    ManagedAgent,

View File

@@ -20,8 +20,8 @@ import numpy as np
from PIL import Image
from transformers import is_torch_available
-from transformers.agents.agent_types import AGENT_TYPE_MAPPING
-from transformers.agents.default_tools import FinalAnswerTool
+from agents.types import AGENT_TYPE_MAPPING
+from agents.default_tools import FinalAnswerTool
from transformers.testing_utils import get_tests_dir, require_torch
from .test_tools_common import ToolTesterMixin

View File

@@ -15,9 +15,7 @@
import unittest
-from transformers.agents.agent_types import AgentImage
-from transformers.agents.agents import AgentError, CodeAgent, JsonAgent
-from transformers.agents.monitoring import stream_to_gradio
+from agents import AgentImage, AgentError, CodeAgent, JsonAgent, stream_to_gradio
class MonitoringTester(unittest.TestCase):
@@ -122,7 +120,7 @@ final_answer('This is the final answer.')
        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 4)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("This is the final answer.", final_message.content)
@@ -149,7 +147,7 @@ final_answer('This is the final answer.')
            )
        )
-        self.assertEqual(len(outputs), 2)
+        self.assertEqual(len(outputs), 3)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIsInstance(final_message.content, dict)
@@ -169,7 +167,7 @@ final_answer('This is the final answer.')
        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 5)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("Simulated agent error", final_message.content)

View File

@@ -18,10 +18,10 @@ import unittest
import numpy as np
import pytest
-from transformers import load_tool
-from transformers.agents.agent_types import AGENT_TYPE_MAPPING
-from transformers.agents.default_tools import BASE_PYTHON_TOOLS
-from transformers.agents.python_interpreter import (
+from agents import load_tool
+from agents.types import AGENT_TYPE_MAPPING
+from agents.default_tools import BASE_PYTHON_TOOLS
+from agents.local_python_executor import (
    InterpreterError,
    evaluate_python_code,
)
@@ -51,6 +51,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
        inputs = ["2 * 2"]
        output = self.tool(*inputs)
        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        print("OKK", type(output), output_type, AGENT_TYPE_MAPPING)
        self.assertTrue(isinstance(output, output_type))
    def test_agent_types_inputs(self):

View File

@@ -15,7 +15,7 @@
import unittest
-from transformers import load_tool
+from agents import load_tool
from .test_tools_common import ToolTesterMixin

View File

@@ -20,13 +20,13 @@ import numpy as np
import pytest
from transformers import is_torch_available, is_vision_available
-from transformers.agents.agent_types import (
+from agents.types import (
    AGENT_TYPE_MAPPING,
    AgentAudio,
    AgentImage,
    AgentText,
)
-from transformers.agents.tools import Tool, tool
+from agents.tools import Tool, tool
from transformers.testing_utils import get_tests_dir

View File

@@ -18,7 +18,7 @@ import unittest
import uuid
from pathlib import Path
-from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
+from agents.types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import (
    get_tests_dir,
    require_soundfile,