Add examples to use any LLM as engine
parent 8ed03634b0
commit 9232528232
@@ -261,7 +261,7 @@ For maximum flexibility, you can overwrite the whole system prompt template by p
```python
from transformers import JsonAgent
-from transformers.agents import PythonInterpreterTool
+from agents import PythonInterpreterTool

agent = JsonAgent(tools=[PythonInterpreterTool()], system_prompt="{your_custom_prompt}")
```
@@ -381,14 +381,14 @@ Multi-agent has been introduced in Microsoft's framework [Autogen](https://huggi
It simply means having several agents working together to solve your task instead of only one.
It empirically yields better performance on most benchmarks. The reason for this better performance is conceptually simple: for many tasks, rather than using a do-it-all system, you would prefer to specialize units on sub-tasks. Here, having agents with separate tool sets and memories makes it possible to achieve efficient specialization.

-You can easily build hierarchical multi-agent systems with `transformers.agents`.
+You can easily build hierarchical multi-agent systems with `agents`.

To do so, encapsulate the agent in a [`ManagedAgent`] object. This object needs the arguments `agent`, `name`, and `description`, which will then be embedded in the manager agent's system prompt to let it know how to call this managed agent, as we also do for tools.

Here's an example of making an agent that manages a specific web search agent using our [`DuckDuckGoSearchTool`]:

```py
-from transformers.agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent
+from agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent

llm_engine = HfApiEngine()
```
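The documentation hunk above is cut off right after `llm_engine = HfApiEngine()`. As a rough sketch of how the rest of the managed-agent example could be wired up (the `managed_agents=` argument on the manager agent and the exact `description` wording are assumptions, not shown in this diff):

```python
from agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent

llm_engine = HfApiEngine()

# Wrap a web-search agent so a manager agent can call it like a tool.
web_agent = CodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=llm_engine)
managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    description="Runs web searches for you. Give it your query as an argument.",
)

# Hypothetical wiring: the manager-side parameter name is an assumption here.
manager_agent = CodeAgent(
    tools=[], llm_engine=llm_engine, managed_agents=[managed_web_agent]
)
manager_agent.run("Who is the CEO of Hugging Face?")
```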
@@ -19,7 +19,7 @@ There's a world of difference between building an agent that works and one that
In this guide, we're going to see best practices for building agents.

> [!TIP]
-> If you're new to `transformers.agents`, make sure to first read the [intro to agents](./intro_agents).
+> If you're new to `agents`, make sure to first read the [intro to agents](./intro_agents).

### The best agentic systems are the simplest: simplify the workflow as much as you can
@@ -156,12 +156,12 @@ These types have three specific purposes:

### AgentText

-[[autodoc]] transformers.agents.agent_types.AgentText
+[[autodoc]] agents.types.AgentText

### AgentImage

-[[autodoc]] transformers.agents.agent_types.AgentImage
+[[autodoc]] agents.types.AgentImage

### AgentAudio

-[[autodoc]] transformers.agents.agent_types.AgentAudio
+[[autodoc]] agents.types.AgentAudio
@@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.
Here, we're going to see advanced tool usage.

> [!TIP]
-> If you're new to `transformers.agents`, make sure to first read the main [agents documentation](./agents).
+> If you're new to `agents`, make sure to first read the main [agents documentation](./agents).

### Directly define a tool by subclassing Tool, and share it to the Hub
@@ -0,0 +1,30 @@
+from agents import OpenAIEngine, AnthropicEngine, HfApiEngine, CodeAgent
+from dotenv import load_dotenv
+
+load_dotenv()
+
+openai_engine = OpenAIEngine(model_name="gpt-4o")
+
+agent = CodeAgent([], llm_engine=openai_engine)
+
+print("\n\n##############")
+print("Running OpenAI agent:")
+agent.run("What is the 10th Fibonacci Number?")
+
+
+anthropic_engine = AnthropicEngine()
+
+agent = CodeAgent([], llm_engine=anthropic_engine)
+
+print("\n\n##############")
+print("Running Anthropic agent:")
+agent.run("What is the 10th Fibonacci Number?")
+
+# Here, the token stored in the HF_TOKEN environment variable has the scopes 'Make calls to the serverless Inference API' and 'Read access to contents of all public gated repos you can access'.
+llama_engine = HfApiEngine(model="meta-llama/Llama-3.3-70B-Instruct")
+
+agent = CodeAgent([], llm_engine=llama_engine)
+
+print("\n\n##############")
+print("Running Llama3.3-70B agent:")
+agent.run("What is the 10th Fibonacci Number?")
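The new example script above reads all of its credentials from the environment via `load_dotenv()`. A small, optional sanity check before running it might look like the following; the variable names (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `HF_TOKEN`) are the ones read by the engine classes added in this commit:

```python
import os

# Variable names as used by OpenAIEngine, AnthropicEngine and HfApiEngine in this commit.
for var in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY", "HF_TOKEN"):
    if not os.getenv(var):
        print(f"Warning: {var} is not set; the corresponding engine will fail to authenticate.")
```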
@@ -22,3 +22,8 @@ dependencies = [
    "duckduckgo-search>=6.3.7",
    "python-dotenv>=1.0.1"
]
+
+[project.optional-dependencies]
+dev = [
+    "anthropic",
+]
@@ -26,9 +26,17 @@ from transformers.utils.import_utils import define_import_structure

if TYPE_CHECKING:
    from .agents import *
-    from .llm_engine import *
    from .default_tools import *
+    from .gradio_ui import *
+    from .llm_engines import *
+    from .local_python_executor import *
+    from .monitoring import *
+    from .prompts import *
+    from .search import *
    from .tools import *
+    from .types import *
+    from .utils import *

else:
    import sys
@@ -22,9 +22,9 @@ from rich.syntax import Syntax
from transformers.utils import is_torch_available

from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
-from .agent_types import AgentAudio, AgentImage
+from .types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
-from .llm_engine import HfApiEngine, MessageRole
+from .llm_engines import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (
    CODE_SYSTEM_PROMPT,
@@ -492,7 +492,7 @@ class ReactAgent(BaseAgent):

    Example:
    ```py
-    from transformers.agents import CodeAgent
+    from agents import CodeAgent
    agent = CodeAgent(tools=[])
    agent.run("What is the result of 2 power 3.7384?")
    ```
@@ -811,7 +811,7 @@ class JsonAgent(ReactAgent):
            )
            log_entry.llm_output = llm_output
        except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm output: {e}.")
+            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")

        if self.verbose:
            console.rule("[italic]Output message of the LLM:")
@@ -944,7 +944,7 @@ class CodeAgent(ReactAgent):
            )
            log_entry.llm_output = llm_output
        except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm output: {e}.")
+            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")

        if self.verbose:
            console.rule("[italic]Output message of the LLM:")
@@ -1074,4 +1074,4 @@ And even if your task resolution is not successful, please return as much contex
        else:
            return output

-__all__ = ["BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
+__all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .agent_types import AgentAudio, AgentImage, AgentText
+from .types import AgentAudio, AgentImage, AgentText
from .agents import BaseAgent, AgentStep, ActionStep
import gradio as gr
@@ -22,10 +22,20 @@ from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, Pipeline
import logging

+import os
+from openai import OpenAI

logger = logging.getLogger(__name__)

+DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
+}
+
+DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
+}
+
class MessageRole(str, Enum):
    USER = "user"
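As a quick illustration of what the `DEFAULT_CODEAGENT_REGEX_GRAMMAR` pattern above constrains, here is a small self-contained check; the sample completion text is made up for the demonstration:

```python
import re

# Pattern copied from DEFAULT_CODEAGENT_REGEX_GRAMMAR["value"] above.
pattern = "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>"

# A hypothetical model completion in the expected Thought/Code format.
sample = (
    "Thought: I will compute the value with Python.\n"
    "Code:\n"
    "```py\n"
    "print(2 ** 3.7384)\n"
    "```<end_action>"
)

assert re.search(pattern, sample) is not None
```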
@@ -38,6 +48,13 @@ class MessageRole(str, Enum):
    def roles(cls):
        return [r.value for r in cls]

+openai_role_conversions = {
+    MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}
+
+llama_role_conversions = {
+    MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}
+
def get_clean_message_list(
    message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
@@ -73,11 +90,6 @@ def get_clean_message_list(
    return final_message_list


-llama_role_conversions = {
-    MessageRole.TOOL_RESPONSE: MessageRole.USER,
-}
-
-
class HfEngine:
    def __init__(self, model_id: Optional[str] = None):
        self.last_input_token_count = None
@@ -106,6 +118,7 @@ class HfEngine:
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
+        max_tokens: int = 1500
    ):
        raise NotImplementedError
@@ -114,6 +127,7 @@ class HfEngine:
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
    ) -> str:
        """Process the input messages and return the model's response.
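The signature above, together with the `__call__` methods that `OpenAIEngine` and `AnthropicEngine` define later in this diff, is essentially the whole contract an engine has to satisfy: a callable that takes a list of role/content message dicts plus optional `stop_sequences`, `grammar` and `max_tokens`, and returns a string. A minimal, hypothetical custom engine for any other LLM backend could therefore look roughly like this (import paths assume the module layout shown in this diff; `fake_llm_backend` is a stand-in, not a real API):

```python
from typing import Dict, List, Optional

# Helpers defined in llm_engines.py as shown in this commit.
from agents.llm_engines import get_clean_message_list, llama_role_conversions


def fake_llm_backend(messages: List[Dict[str, str]], max_tokens: int) -> str:
    """Stand-in for whatever third-party LLM client you actually call."""
    return f"(pretend completion for {len(messages)} messages, max_tokens={max_tokens})"


class MyCustomEngine:
    """Minimal sketch of an engine wrapping an arbitrary LLM backend."""

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        max_tokens: int = 1500,
    ) -> str:
        # Normalize roles so the backend only ever sees roles it understands.
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
        return fake_llm_backend(messages, max_tokens=max_tokens)
```

An instance of such a class can then be passed anywhere the examples above pass `llm_engine=...`, e.g. `CodeAgent([], llm_engine=MyCustomEngine())`.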
@@ -133,7 +147,7 @@ class HfEngine:
        Example:
        ```python
        >>> engine = HfApiEngine(
-        ...     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        ...     model="Qwen/Qwen2.5-Coder-32B-Instruct",
        ...     token="your_hf_token_here",
        ...     max_tokens=2000
        ... )
@@ -149,7 +163,7 @@ class HfEngine:
        )
        if stop_sequences is None:
            stop_sequences = []
-        response = self.generate(messages, stop_sequences, grammar)
+        response = self.generate(messages, stop_sequences, grammar, max_tokens)
        self.last_input_token_count = len(
            self.tokenizer.apply_chat_template(messages, tokenize=True)
        )
@@ -168,11 +182,12 @@ class HfApiEngine(HfEngine):
    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used both in serverless mode and with a dedicated endpoint, supporting features like stop sequences and grammar customization.

    Parameters:
-        model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
+        model (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
        token (`str`, *optional*):
-            Token used by the Hugging Face API for authentication.
-            If not provided, the class will use the token stored in the Hugging Face CLI configuration.
+            Token used by the Hugging Face API for authentication. This token needs to be authorized for 'Make calls to the serverless Inference API'.
+            If the model is gated (like Llama-3 models), the token also needs 'Read access to contents of all public gated repos you can access'.
+            If not provided, the class will try to use the environment variable 'HF_TOKEN', and otherwise fall back to the token stored in the Hugging Face CLI configuration.
        max_tokens (`int`, *optional*, defaults to 1500):
            The maximum number of tokens allowed in the output.
        timeout (`int`, *optional*, defaults to 120):
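Given those defaults, the simplest usage is a bare constructor call. This is only a hedged sketch: it assumes a valid token is available via `HF_TOKEN` or the Hugging Face CLI login, plus network access to the Inference API:

```python
from agents import HfApiEngine

# Uses the default model (Qwen/Qwen2.5-Coder-32B-Instruct) and, when no token is passed,
# falls back to the HF_TOKEN environment variable or the Hugging Face CLI configuration.
engine = HfApiEngine()

messages = [{"role": "user", "content": "Say hello in one word."}]
print(engine(messages, max_tokens=50))
```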
@@ -185,21 +200,22 @@ class HfApiEngine(HfEngine):
    def __init__(
        self,
-        model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        model: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
        token: Optional[str] = None,
        max_tokens: Optional[int] = 1500,
        timeout: Optional[int] = 120,
    ):
        super().__init__(model_id=model)
        self.model = model
        if token is None:
            token = os.getenv("HF_TOKEN")
        self.client = InferenceClient(self.model, token=token, timeout=timeout)
        self.max_tokens = max_tokens

    def generate(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(
@@ -211,12 +227,12 @@ class HfApiEngine(HfEngine):
            response = self.client.chat_completion(
                messages,
                stop=stop_sequences,
-                max_tokens=self.max_tokens,
                response_format=grammar,
+                max_tokens=max_tokens,
            )
        else:
            response = self.client.chat_completion(
-                messages, stop=stop_sequences, max_tokens=self.max_tokens
+                messages, stop=stop_sequences, max_tokens=max_tokens
            )

        response = response.choices[0].message.content
@@ -235,7 +251,7 @@ class TransformersEngine(HfEngine):
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
-        max_length: int = 1500,
+        max_tokens: int = 1500,
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(
@@ -251,7 +267,7 @@ class TransformersEngine(HfEngine):
        output = self.pipeline(
            messages,
            stop_strings=stop_strings,
-            max_length=max_length,
+            max_length=max_tokens,
            tokenizer=self.pipeline.tokenizer,
        )
@@ -259,14 +275,95 @@ class TransformersEngine(HfEngine):
        return response


-DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
-    "type": "regex",
-    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
-}
-
-DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
-    "type": "regex",
-    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
-}
-
-__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine"]
+class OpenAIEngine:
+    def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None):
+        """Creates an LLM engine that follows the OpenAI format.
+
+        Args:
+            model_name (`str`, *optional*): the model name to use.
+            api_key (`str`, *optional*): your API key.
+            base_url (`str`, *optional*): the URL to use if using a different inference service than OpenAI, for instance "https://api-inference.huggingface.co/v1/".
+        """
+        if model_name is None:
+            model_name = "gpt-4o"
+        if api_key is None:
+            api_key = os.getenv("OPENAI_API_KEY")
+        self.model_name = model_name
+        self.client = OpenAI(
+            base_url=base_url,
+            api_key=api_key,
+        )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            stop=stop_sequences,
+            temperature=0.5,
+            max_tokens=max_tokens,
+        )
+        return response.choices[0].message.content
+
+
+class AnthropicEngine:
+    def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):
+        from anthropic import Anthropic, AnthropicBedrock
+
+        self.model_name = model_name
+        if use_bedrock:
+            self.model_name = "anthropic.claude-3-5-sonnet-20240620-v1:0"
+            self.client = AnthropicBedrock(
+                aws_access_key=os.getenv("AWS_BEDROCK_ID"),
+                aws_secret_key=os.getenv("AWS_BEDROCK_KEY"),
+                aws_region="us-east-1",
+            )
+        else:
+            self.client = Anthropic(
+                api_key=os.getenv("ANTHROPIC_API_KEY"),
+            )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
+        index_system_message, system_prompt = None, None
+        for index, message in enumerate(messages):
+            if message["role"] == MessageRole.SYSTEM:
+                index_system_message = index
+                system_prompt = message["content"]
+        if system_prompt is None:
+            raise Exception("No system prompt found!")
+
+        filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message]
+        if len(filtered_messages) == 0:
+            print("Error, no user message:", messages)
+            assert False
+
+        response = self.client.messages.create(
+            model=self.model_name,
+            system=system_prompt,
+            messages=filtered_messages,
+            stop_sequences=stop_sequences,
+            temperature=0.5,
+            max_tokens=max_tokens,
+        )
+        full_response_text = ""
+        for content_block in response.content:
+            if content_block.type == "text":
+                full_response_text += content_block.text
+        return full_response_text
+
+
+__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"]
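Both new engines are plain callables, so they drop into the same `llm_engine=` slot as `HfApiEngine`. As a hedged usage sketch (the OpenAI-compatible `base_url` comes from the `OpenAIEngine` docstring above; model names and keys are placeholders to substitute with your own):

```python
from agents import CodeAgent, OpenAIEngine, AnthropicEngine

# Any OpenAI-compatible server works by overriding base_url, for instance the
# Hugging Face endpoint mentioned in the OpenAIEngine docstring.
hf_openai_style_engine = OpenAIEngine(
    model_name="Qwen/Qwen2.5-Coder-32B-Instruct",
    api_key="your_hf_token_here",
    base_url="https://api-inference.huggingface.co/v1/",
)

# Claude through AWS Bedrock instead of the Anthropic API
# (expects AWS_BEDROCK_ID / AWS_BEDROCK_KEY in the environment).
bedrock_engine = AnthropicEngine(use_bedrock=True)

agent = CodeAgent([], llm_engine=hf_openai_style_engine)
agent.run("What is the 10th Fibonacci Number?")
```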
@@ -231,7 +231,6 @@ def evaluate_class_def(class_def, state, static_tools, custom_tools):
-

def evaluate_augassign(expression, state, static_tools, custom_tools):
    # Helper function to get current value and set new value based on the target type
    def get_current_value(target):
        if isinstance(target, ast.Name):
            return state.get(target.id, 0)
@@ -254,7 +253,6 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
    current_value = get_current_value(expression.target)
    value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
-
    # Determine the operation and apply it
    if isinstance(expression.op, ast.Add):
        if isinstance(current_value, list):
            if not isinstance(value_to_add, list):
@@ -51,7 +51,7 @@ from transformers.utils import (
    is_torch_available,
    is_vision_available,
)
-from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
+from .types import ImageType, handle_agent_inputs, handle_agent_outputs
import logging

logger = logging.getLogger(__name__)
@@ -928,8 +928,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
    """
    if task_or_repo_id in TOOL_MAPPING:
        tool_class_name = TOOL_MAPPING[task_or_repo_id]
-        main_module = importlib.import_module("transformers")
-        tools_module = main_module.agents
+        main_module = importlib.import_module("agents")
+        tools_module = main_module
        tool_class = getattr(tools_module, tool_class_name)
        return tool_class(model_repo_id, token=token, **kwargs)
    else:
@@ -20,7 +20,7 @@ import pytest

from pathlib import Path

-from agents.agent_types import AgentText
+from agents.types import AgentText
from agents.agents import (
    AgentMaxIterationsError,
    ManagedAgent,
@@ -20,8 +20,8 @@ import numpy as np
from PIL import Image

from transformers import is_torch_available
-from transformers.agents.agent_types import AGENT_TYPE_MAPPING
-from transformers.agents.default_tools import FinalAnswerTool
+from agents.types import AGENT_TYPE_MAPPING
+from agents.default_tools import FinalAnswerTool
from transformers.testing_utils import get_tests_dir, require_torch

from .test_tools_common import ToolTesterMixin
@@ -15,9 +15,7 @@

import unittest

-from transformers.agents.agent_types import AgentImage
-from transformers.agents.agents import AgentError, CodeAgent, JsonAgent
-from transformers.agents.monitoring import stream_to_gradio
+from agents import AgentImage, AgentError, CodeAgent, JsonAgent, stream_to_gradio


class MonitoringTester(unittest.TestCase):
@@ -122,7 +120,7 @@ final_answer('This is the final answer.')
        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 4)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("This is the final answer.", final_message.content)
@@ -149,7 +147,7 @@ final_answer('This is the final answer.')
            )
        )

-        self.assertEqual(len(outputs), 2)
+        self.assertEqual(len(outputs), 3)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIsInstance(final_message.content, dict)
@@ -169,7 +167,7 @@ final_answer('This is the final answer.')
        # Use stream_to_gradio to capture the output
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 5)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("Simulated agent error", final_message.content)
@@ -18,10 +18,10 @@ import unittest
import numpy as np
import pytest

-from transformers import load_tool
-from transformers.agents.agent_types import AGENT_TYPE_MAPPING
-from transformers.agents.default_tools import BASE_PYTHON_TOOLS
-from transformers.agents.python_interpreter import (
+from agents import load_tool
+from agents.types import AGENT_TYPE_MAPPING
+from agents.default_tools import BASE_PYTHON_TOOLS
+from agents.local_python_executor import (
    InterpreterError,
    evaluate_python_code,
)
@@ -51,6 +51,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
        inputs = ["2 * 2"]
        output = self.tool(*inputs)
        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        print("OKK", type(output), output_type, AGENT_TYPE_MAPPING)
        self.assertTrue(isinstance(output, output_type))

    def test_agent_types_inputs(self):
@@ -15,7 +15,7 @@

import unittest

-from transformers import load_tool
+from agents import load_tool

from .test_tools_common import ToolTesterMixin
@@ -20,13 +20,13 @@ import numpy as np
import pytest

from transformers import is_torch_available, is_vision_available
-from transformers.agents.agent_types import (
+from agents.types import (
    AGENT_TYPE_MAPPING,
    AgentAudio,
    AgentImage,
    AgentText,
)
-from transformers.agents.tools import Tool, tool
+from agents.tools import Tool, tool
from transformers.testing_utils import get_tests_dir
@@ -18,7 +18,7 @@ import unittest
import uuid
from pathlib import Path

-from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
+from agents.types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import (
    get_tests_dir,
    require_soundfile,