Add examples to use any LLM as engine
commit 9232528232
parent 8ed03634b0
@@ -261,7 +261,7 @@ For maximum flexibility, you can overwrite the whole system prompt template by p
 
 ```python
 from transformers import JsonAgent
-from transformers.agents import PythonInterpreterTool
+from agents import PythonInterpreterTool
 
 agent = JsonAgent(tools=[PythonInterpreterTool()], system_prompt="{your_custom_prompt}")
 ```
@@ -381,14 +381,14 @@ Multi-agent has been introduced in Microsoft's framework [Autogen](https://huggi
 It simply means having several agents working together to solve your task instead of only one.
 It empirically yields better performance on most benchmarks. The reason for this better performance is conceptually simple: for many tasks, rather than using a do-it-all system, you would prefer to specialize units on sub-tasks. Here, having agents with separate tool sets and memories allows to achieve efficient specialization.
 
-You can easily build hierarchical multi-agent systems with `transformers.agents`.
+You can easily build hierarchical multi-agent systems with `agents`.
 
 To do so, encapsulate the agent in a [`ManagedAgent`] object. This object needs arguments `agent`, `name`, and a `description`, which will then be embedded in the manager agent's system prompt to let it know how to call this managed agent, as we also do for tools.
 
 Here's an example of making an agent that managed a specific web search agent using our [`DuckDuckGoSearchTool`]:
 
 ```py
-from transformers.agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent
+from agents import CodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent
 
 llm_engine = HfApiEngine()
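The doc example in this hunk cuts off after `llm_engine = HfApiEngine()`. A minimal sketch of how it might continue, based on the `ManagedAgent` arguments (`agent`, `name`, `description`) described in the prose above; the `managed_agents` parameter on the manager agent is an assumption not confirmed by this diff:

```py
web_agent = CodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=llm_engine)

managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    # The description is embedded in the manager's system prompt, like a tool description
    description="Runs web searches for you. Give it your query as an argument.",
)

# managed_agents is assumed to be how sub-agents are handed to a manager agent
manager_agent = CodeAgent(tools=[], llm_engine=llm_engine, managed_agents=[managed_web_agent])
manager_agent.run("Who is the CEO of Hugging Face?")
```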
@@ -19,7 +19,7 @@ There's a world of difference between building an agent that works and one that
 In this guide, we're going to see best practices for building agents.
 
 > [!TIP]
-> If you're new to `transformers.agents`, make sure to first read the [intro to agents](./intro_agents).
+> If you're new to `agents`, make sure to first read the [intro to agents](./intro_agents).
 
 ### The best agentic systems are the simplest: simplify the workflow as much as you can
@@ -156,12 +156,12 @@ These types have three specific purposes:
 
 ### AgentText
 
-[[autodoc]] transformers.agents.agent_types.AgentText
+[[autodoc]] agents.types.AgentText
 
 ### AgentImage
 
-[[autodoc]] transformers.agents.agent_types.AgentImage
+[[autodoc]] agents.types.AgentImage
 
 ### AgentAudio
 
-[[autodoc]] transformers.agents.agent_types.AgentAudio
+[[autodoc]] agents.types.AgentAudio
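These agent types exist so that multimodal tool outputs still behave like their native Python types. A minimal sketch of that behavior, assuming `AgentText` subclasses `str` and `AgentImage` wraps a PIL image (inferred from their use alongside `AGENT_TYPE_MAPPING` elsewhere in this commit, not confirmed here):

```py
from agents.types import AgentText, AgentImage

text = AgentText("final answer")
print(text.upper())          # usable anywhere a str is expected

image = AgentImage("path/to/image.png")  # hypothetical path
raw = image.to_raw()         # assumed accessor returning the underlying PIL.Image
```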
@@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.
 Here, we're going to see advanced tool usage.
 
 > [!TIP]
-> If you're new to `transformers.agents`, make sure to first read the main [agents documentation](./agents).
+> If you're new to `agents`, make sure to first read the main [agents documentation](./agents).
 
 ### Directly define a tool by subclassing Tool, and share it to the Hub
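For the heading above, a minimal sketch of subclassing `Tool`, assuming the class attributes (`name`, `description`, `inputs`, `output_type`) and the `forward` method that `agents` inherits from `transformers.agents`; the tool itself is illustrative:

```py
from agents import Tool

class ModelDownloadsTool(Tool):
    name = "model_download_counter"
    description = "Returns the most downloaded model for a given task on the Hugging Face Hub."
    inputs = {"task": {"type": "string", "description": "the task, e.g. 'text-classification'"}}
    output_type = "string"

    def forward(self, task: str) -> str:
        from huggingface_hub import list_models
        # Take the top result when sorting by downloads, descending
        model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
        return model.id
```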
@@ -0,0 +1,30 @@
+from agents import OpenAIEngine, AnthropicEngine, HfApiEngine, CodeAgent
+from dotenv import load_dotenv
+
+load_dotenv()
+
+openai_engine = OpenAIEngine(model_name="gpt-4o")
+
+agent = CodeAgent([], llm_engine=openai_engine)
+
+print("\n\n##############")
+print("Running OpenAI agent:")
+agent.run("What is the 10th Fibonacci Number?")
+
+
+anthropic_engine = AnthropicEngine()
+
+agent = CodeAgent([], llm_engine=anthropic_engine)
+
+print("\n\n##############")
+print("Running Anthropic agent:")
+agent.run("What is the 10th Fibonacci Number?")
+
+# Here, our token stored as HF_TOKEN environment variable has accesses 'Make calls to the serverless Inference API' and 'Read access to contents of all public gated repos you can access'
+llama_engine = HfApiEngine(model="meta-llama/Llama-3.3-70B-Instruct")
+
+agent = CodeAgent([], llm_engine=llama_engine)
+
+print("\n\n##############")
+print("Running Llama3.3-70B agent:")
+agent.run("What is the 10th Fibonacci Number?")
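Note: the new example script relies on `load_dotenv()` to pull credentials from a local `.env` file; judging from the engine code later in this commit, the variables it reads are `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, and `HF_TOKEN`.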
@@ -22,3 +22,8 @@ dependencies = [
     "duckduckgo-search>=6.3.7",
     "python-dotenv>=1.0.1"
 ]
+
+[project.optional-dependencies]
+dev = [
+    "anthropic",
+]
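With the `anthropic` client moved to an optional `dev` extra, installing it alongside the package is `pip install -e ".[dev]"` from the repository root.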
@@ -26,9 +26,17 @@ from transformers.utils.import_utils import define_import_structure
 
 if TYPE_CHECKING:
     from .agents import *
-    from .llm_engine import *
+    from .default_tools import *
+    from .gradio_ui import *
+    from .llm_engines import *
+    from .local_python_executor import *
     from .monitoring import *
+    from .prompts import *
+    from .search import *
     from .tools import *
+    from .types import *
+    from .utils import *
 
 else:
     import sys
@@ -22,9 +22,9 @@ from rich.syntax import Syntax
 from transformers.utils import is_torch_available
 
 from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
-from .agent_types import AgentAudio, AgentImage
+from .types import AgentAudio, AgentImage
 from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
-from .llm_engine import HfApiEngine, MessageRole
+from .llm_engines import HfApiEngine, MessageRole
 from .monitoring import Monitor
 from .prompts import (
     CODE_SYSTEM_PROMPT,
@@ -492,7 +492,7 @@ class ReactAgent(BaseAgent):
 
     Example:
     ```py
-    from transformers.agents import CodeAgent
+    from agents import CodeAgent
     agent = CodeAgent(tools=[])
     agent.run("What is the result of 2 power 3.7384?")
     ```
@@ -811,7 +811,7 @@ class JsonAgent(ReactAgent):
             )
             log_entry.llm_output = llm_output
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm output: {e}.")
+            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
 
         if self.verbose:
             console.rule("[italic]Output message of the LLM:")
@@ -944,7 +944,7 @@ class CodeAgent(ReactAgent):
             )
             log_entry.llm_output = llm_output
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating llm output: {e}.")
+            raise AgentGenerationError(f"Error in generating llm_engine output: {e}.")
 
         if self.verbose:
             console.rule("[italic]Output message of the LLM:")
@@ -1074,4 +1074,4 @@ And even if your task resolution is not successful, please return as much contex
         else:
             return output
 
-__all__ = ["BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
+__all__ = ["AgentError", "BaseAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"]
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .agent_types import AgentAudio, AgentImage, AgentText
+from .types import AgentAudio, AgentImage, AgentText
 from .agents import BaseAgent, AgentStep, ActionStep
 import gradio as gr
 
@@ -22,10 +22,20 @@ from huggingface_hub import InferenceClient
 
 from transformers import AutoTokenizer, Pipeline
 import logging
+import os
+from openai import OpenAI
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
+}
+
+DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
+}
+
 class MessageRole(str, Enum):
     USER = "user"
@@ -38,6 +48,13 @@ class MessageRole(str, Enum):
     def roles(cls):
         return [r.value for r in cls]
 
+
+openai_role_conversions = {
+    MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}
+
+llama_role_conversions = {
+    MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}
 
 def get_clean_message_list(
     message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
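The role-conversion maps added above feed `get_clean_message_list`. A minimal sketch of the intended use, assuming messages are plain `{"role", "content"}` dicts as in the function's signature:

```py
from agents.llm_engines import MessageRole, get_clean_message_list, openai_role_conversions

messages = [
    {"role": MessageRole.SYSTEM, "content": "You are an agent."},
    {"role": MessageRole.TOOL_RESPONSE, "content": "Observation: 42"},
]
# Tool responses get remapped to the "user" role that OpenAI-style chat APIs accept
clean = get_clean_message_list(messages, role_conversions=openai_role_conversions)
```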
@@ -73,11 +90,6 @@ def get_clean_message_list(
     return final_message_list
 
 
-llama_role_conversions = {
-    MessageRole.TOOL_RESPONSE: MessageRole.USER,
-}
-
-
 class HfEngine:
     def __init__(self, model_id: Optional[str] = None):
         self.last_input_token_count = None
@@ -106,6 +118,7 @@ class HfEngine:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
+        max_tokens: int = 1500
     ):
         raise NotImplementedError
 
@@ -114,6 +127,7 @@ class HfEngine:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
+        max_tokens: int = 1500,
     ) -> str:
         """Process the input messages and return the model's response.
 
@@ -133,7 +147,7 @@ class HfEngine:
         Example:
         ```python
         >>> engine = HfApiEngine(
-        ...     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        ...     model="Qwen/Qwen2.5-Coder-32B-Instruct",
         ...     token="your_hf_token_here",
         ...     max_tokens=2000
         ... )
@@ -149,7 +163,7 @@ class HfEngine:
         )
         if stop_sequences is None:
             stop_sequences = []
-        response = self.generate(messages, stop_sequences, grammar)
+        response = self.generate(messages, stop_sequences, grammar, max_tokens)
         self.last_input_token_count = len(
             self.tokenizer.apply_chat_template(messages, tokenize=True)
         )
@@ -168,11 +182,12 @@ class HfApiEngine(HfEngine):
     This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
     Parameters:
-        model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
+        model (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
         token (`str`, *optional*):
-            Token used by the Hugging Face API for authentication.
-            If not provided, the class will use the token stored in the Hugging Face CLI configuration.
+            Token used by the Hugging Face API for authentication. This token need to be authorized 'Make calls to the serverless Inference API'.
+            If the model is gated (like Llama-3 models), the token also needs 'Read access to contents of all public gated repos you can access'.
+            If not provided, the class will try to use environment variable 'HF_TOKEN', else use the token stored in the Hugging Face CLI configuration.
         max_tokens (`int`, *optional*, defaults to 1500):
             The maximum number of tokens allowed in the output.
         timeout (`int`, *optional*, defaults to 120):
@@ -185,21 +200,22 @@ class HfApiEngine(HfEngine):
 
     def __init__(
         self,
-        model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        model: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
-        max_tokens: Optional[int] = 1500,
         timeout: Optional[int] = 120,
     ):
         super().__init__(model_id=model)
         self.model = model
+        if token is None:
+            token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model, token=token, timeout=timeout)
-        self.max_tokens = max_tokens
 
     def generate(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
+        max_tokens: int = 1500,
     ) -> str:
         # Get clean message list
         messages = get_clean_message_list(
@@ -211,12 +227,12 @@ class HfApiEngine(HfEngine):
             response = self.client.chat_completion(
                 messages,
                 stop=stop_sequences,
-                max_tokens=self.max_tokens,
                 response_format=grammar,
+                max_tokens=max_tokens,
             )
         else:
             response = self.client.chat_completion(
-                messages, stop=stop_sequences, max_tokens=self.max_tokens
+                messages, stop=stop_sequences, max_tokens=max_tokens
             )
 
         response = response.choices[0].message.content
@@ -235,7 +251,7 @@ class TransformersEngine(HfEngine):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_length: int = 1500,
+        max_tokens: int = 1500,
     ) -> str:
         # Get clean message list
         messages = get_clean_message_list(
@@ -251,7 +267,7 @@ class TransformersEngine(HfEngine):
         output = self.pipeline(
             messages,
             stop_strings=stop_strings,
-            max_length=max_length,
+            max_length=max_tokens,
             tokenizer=self.pipeline.tokenizer,
         )
 
@@ -259,14 +275,95 @@ class TransformersEngine(HfEngine):
         return response
 
 
-DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
-    "type": "regex",
-    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
-}
-
-DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
-    "type": "regex",
-    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
-}
-
-__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine"]
+class OpenAIEngine:
+    def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None):
+        """Creates a LLM Engine that follows OpenAI format.
+
+        Args:
+            model_name (`str`, *optional*): the model name to use.
+            api_key (`str`, *optional*): your API key.
+            base_url (`str`, *optional*): the URL to use if using a different inference service than OpenAI, for instance "https://api-inference.huggingface.co/v1/".
+        """
+        if model_name is None:
+            model_name = "gpt-4o"
+        if api_key is None:
+            api_key = os.getenv("OPENAI_API_KEY")
+        self.model_name = model_name
+        self.client = OpenAI(
+            base_url=base_url,
+            api_key=api_key,
+        )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            stop=stop_sequences,
+            temperature=0.5,
+            max_tokens=max_tokens,
+        )
+        return response.choices[0].message.content
+
+
+class AnthropicEngine:
+    def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):
+        from anthropic import Anthropic, AnthropicBedrock
+
+        self.model_name = model_name
+        if use_bedrock:
+            self.model_name = "anthropic.claude-3-5-sonnet-20240620-v1:0"
+            self.client = AnthropicBedrock(
+                aws_access_key=os.getenv("AWS_BEDROCK_ID"),
+                aws_secret_key=os.getenv("AWS_BEDROCK_KEY"),
+                aws_region="us-east-1",
+            )
+        else:
+            self.client = Anthropic(
+                api_key=os.getenv("ANTHROPIC_API_KEY"),
+            )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
+        index_system_message, system_prompt = None, None
+        for index, message in enumerate(messages):
+            if message["role"] == MessageRole.SYSTEM:
+                index_system_message = index
+                system_prompt = message["content"]
+        if system_prompt is None:
+            raise Exception("No system prompt found!")
+
+        filtered_messages = [message for i, message in enumerate(messages) if i != index_system_message]
+        if len(filtered_messages) == 0:
+            print("Error, no user message:", messages)
+            assert False
+
+        response = self.client.messages.create(
+            model=self.model_name,
+            system=system_prompt,
+            messages=filtered_messages,
+            stop_sequences=stop_sequences,
+            temperature=0.5,
+            max_tokens=max_tokens,
+        )
+        full_response_text = ""
+        for content_block in response.content:
+            if content_block.type == "text":
+                full_response_text += content_block.text
+        return full_response_text
+
+
+__all__ = ["MessageRole", "llama_role_conversions", "get_clean_message_list", "HfEngine", "TransformersEngine", "HfApiEngine", "OpenAIEngine", "AnthropicEngine"]
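Since `OpenAIEngine` exposes `base_url`, any OpenAI-compatible endpoint can serve as the backend. A minimal sketch, reusing the example URL from the docstring above; the model name and `HF_TOKEN` reuse are illustrative assumptions:

```py
import os

from agents import CodeAgent
from agents.llm_engines import OpenAIEngine

# Point the OpenAI client at Hugging Face's OpenAI-compatible endpoint;
# assumes HF_TOKEN is set in the environment
engine = OpenAIEngine(
    model_name="Qwen/Qwen2.5-Coder-32B-Instruct",
    api_key=os.getenv("HF_TOKEN"),
    base_url="https://api-inference.huggingface.co/v1/",
)

agent = CodeAgent([], llm_engine=engine)
agent.run("What is the 10th Fibonacci Number?")
```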
@@ -231,7 +231,6 @@ def evaluate_class_def(class_def, state, static_tools, custom_tools):
 
 
 def evaluate_augassign(expression, state, static_tools, custom_tools):
-    # Helper function to get current value and set new value based on the target type
     def get_current_value(target):
         if isinstance(target, ast.Name):
             return state.get(target.id, 0)
@@ -254,7 +253,6 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
     current_value = get_current_value(expression.target)
     value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
 
-    # Determine the operation and apply it
     if isinstance(expression.op, ast.Add):
         if isinstance(current_value, list):
             if not isinstance(value_to_add, list):
@@ -51,7 +51,7 @@ from transformers.utils import (
     is_torch_available,
     is_vision_available,
 )
-from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
+from .types import ImageType, handle_agent_inputs, handle_agent_outputs
 import logging
 
 logger = logging.getLogger(__name__)
@@ -928,8 +928,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
     """
     if task_or_repo_id in TOOL_MAPPING:
         tool_class_name = TOOL_MAPPING[task_or_repo_id]
-        main_module = importlib.import_module("transformers")
-        tools_module = main_module.agents
+        main_module = importlib.import_module("agents")
+        tools_module = main_module
         tool_class = getattr(tools_module, tool_class_name)
         return tool_class(model_repo_id, token=token, **kwargs)
     else:
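The hunk above redirects `load_tool`'s dynamic import from `transformers` to the `agents` package. A minimal sketch of the call path it preserves, where "text-to-speech" stands in for whatever keys `TOOL_MAPPING` actually contains (an assumption, since the mapping is not shown in this diff):

```py
from agents import load_tool

# Resolves the task name through TOOL_MAPPING, imports the tool class from
# the agents package, and instantiates it
tts_tool = load_tool("text-to-speech")
audio = tts_tool("Hello, world")
```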
@@ -20,7 +20,7 @@ import pytest
 
 from pathlib import Path
 
-from agents.agent_types import AgentText
+from agents.types import AgentText
 from agents.agents import (
     AgentMaxIterationsError,
     ManagedAgent,
@@ -20,8 +20,8 @@ import numpy as np
 from PIL import Image
 
 from transformers import is_torch_available
-from transformers.agents.agent_types import AGENT_TYPE_MAPPING
-from transformers.agents.default_tools import FinalAnswerTool
+from agents.types import AGENT_TYPE_MAPPING
+from agents.default_tools import FinalAnswerTool
 from transformers.testing_utils import get_tests_dir, require_torch
 
 from .test_tools_common import ToolTesterMixin
@@ -15,9 +15,7 @@
 
 import unittest
 
-from transformers.agents.agent_types import AgentImage
-from transformers.agents.agents import AgentError, CodeAgent, JsonAgent
-from transformers.agents.monitoring import stream_to_gradio
+from agents import AgentImage, AgentError, CodeAgent, JsonAgent, stream_to_gradio
 
 
 class MonitoringTester(unittest.TestCase):
@@ -122,7 +120,7 @@ final_answer('This is the final answer.')
         # Use stream_to_gradio to capture the output
         outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
 
-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 4)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("This is the final answer.", final_message.content)
@@ -149,7 +147,7 @@ final_answer('This is the final answer.')
             )
         )
 
-        self.assertEqual(len(outputs), 2)
+        self.assertEqual(len(outputs), 3)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIsInstance(final_message.content, dict)
@@ -169,7 +167,7 @@ final_answer('This is the final answer.')
         # Use stream_to_gradio to capture the output
         outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
 
-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 5)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("Simulated agent error", final_message.content)
@@ -18,10 +18,10 @@ import unittest
 import numpy as np
 import pytest
 
-from transformers import load_tool
-from transformers.agents.agent_types import AGENT_TYPE_MAPPING
-from transformers.agents.default_tools import BASE_PYTHON_TOOLS
-from transformers.agents.python_interpreter import (
+from agents import load_tool
+from agents.types import AGENT_TYPE_MAPPING
+from agents.default_tools import BASE_PYTHON_TOOLS
+from agents.local_python_executor import (
     InterpreterError,
     evaluate_python_code,
 )
@@ -51,6 +51,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
         inputs = ["2 * 2"]
         output = self.tool(*inputs)
         output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        print("OKK", type(output), output_type, AGENT_TYPE_MAPPING)
         self.assertTrue(isinstance(output, output_type))
 
     def test_agent_types_inputs(self):
@@ -15,7 +15,7 @@
 
 import unittest
 
-from transformers import load_tool
+from agents import load_tool
 
 from .test_tools_common import ToolTesterMixin
 
@@ -20,13 +20,13 @@ import numpy as np
 import pytest
 
 from transformers import is_torch_available, is_vision_available
-from transformers.agents.agent_types import (
+from agents.types import (
     AGENT_TYPE_MAPPING,
     AgentAudio,
     AgentImage,
     AgentText,
 )
-from transformers.agents.tools import Tool, tool
+from agents.tools import Tool, tool
 from transformers.testing_utils import get_tests_dir
 
 
@@ -18,7 +18,7 @@ import unittest
 import uuid
 from pathlib import Path
 
-from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
+from agents.types import AgentAudio, AgentImage, AgentText
 from transformers.testing_utils import (
     get_tests_dir,
     require_soundfile,