Formatting

This commit is contained in:
Aymeric 2024-12-23 17:35:34 +01:00
parent 32d7bc5e06
commit cb7e68f2f0
3 changed files with 9 additions and 34 deletions

View File

@ -3,7 +3,7 @@ from agents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine
# Choose which LLM engine to use! # Choose which LLM engine to use!
llm_engine = OpenAIEngine("gpt-4o") llm_engine = OpenAIEngine("gpt-4o")
llm_engine = AnthropicEngine() llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct") llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
@tool @tool

View File

@ -23,8 +23,6 @@ from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.text import Text from rich.text import Text
from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
from .types import AgentAudio, AgentImage from .types import AgentAudio, AgentImage
from .default_tools.base import FinalAnswerTool from .default_tools.base import FinalAnswerTool
@ -217,11 +215,6 @@ class BaseAgent:
if isinstance(tools, Toolbox): if isinstance(tools, Toolbox):
self._toolbox = tools self._toolbox = tools
if add_base_tools: if add_base_tools:
if not is_torch_available():
raise ImportError(
"Using the base tools requires torch to be installed."
)
self._toolbox.add_base_tools( self._toolbox.add_base_tools(
add_python_interpreter=(self.__class__ == JsonAgent) add_python_interpreter=(self.__class__ == JsonAgent)
) )
@ -398,21 +391,17 @@ class MultiStepAgent(BaseAgent):
tools: List[Tool], tools: List[Tool],
llm_engine: Optional[Callable] = None, llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
if system_prompt is None: if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__( super().__init__(
tools=tools, tools=tools,
llm_engine=llm_engine, llm_engine=llm_engine,
system_prompt=system_prompt, system_prompt=system_prompt,
tool_description_template=tool_description_template,
grammar=grammar, grammar=grammar,
**kwargs, **kwargs,
) )
@ -775,7 +764,6 @@ class JsonAgent(MultiStepAgent):
tools: List[Tool], tools: List[Tool],
llm_engine: Optional[Callable] = None, llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
**kwargs, **kwargs,
@ -784,13 +772,10 @@ class JsonAgent(MultiStepAgent):
llm_engine = HfApiEngine() llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = JSON_SYSTEM_PROMPT system_prompt = JSON_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__( super().__init__(
tools=tools, tools=tools,
llm_engine=llm_engine, llm_engine=llm_engine,
system_prompt=system_prompt, system_prompt=system_prompt,
tool_description_template=tool_description_template,
grammar=grammar, grammar=grammar,
planning_interval=planning_interval, planning_interval=planning_interval,
**kwargs, **kwargs,
@ -896,7 +881,6 @@ class ToolCallingAgent(MultiStepAgent):
tools: List[Tool], tools: List[Tool],
llm_engine: Optional[Callable] = None, llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
@ -904,13 +888,10 @@ class ToolCallingAgent(MultiStepAgent):
llm_engine = HfApiEngine() llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = TOOL_CALLING_SYSTEM_PROMPT system_prompt = TOOL_CALLING_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__( super().__init__(
tools=tools, tools=tools,
llm_engine=llm_engine, llm_engine=llm_engine,
system_prompt=system_prompt, system_prompt=system_prompt,
tool_description_template=tool_description_template,
planning_interval=planning_interval, planning_interval=planning_interval,
**kwargs, **kwargs,
) )
@ -986,7 +967,6 @@ class CodeAgent(MultiStepAgent):
tools: List[Tool], tools: List[Tool],
llm_engine: Optional[Callable] = None, llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
@ -997,13 +977,10 @@ class CodeAgent(MultiStepAgent):
llm_engine = HfApiEngine() llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__( super().__init__(
tools=tools, tools=tools,
llm_engine=llm_engine, llm_engine=llm_engine,
system_prompt=system_prompt, system_prompt=system_prompt,
tool_description_template=tool_description_template,
grammar=grammar, grammar=grammar,
planning_interval=planning_interval, planning_interval=planning_interval,
**kwargs, **kwargs,
@ -1092,7 +1069,9 @@ class CodeAgent(MultiStepAgent):
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall( log_entry.tool_call = ToolCall(
tool_name="python_interpreter", tool_arguments=code_action name="python_interpreter",
arguments=code_action,
id=f"call_{len(self.logs)}",
) )
# Execute # Execute

View File

@ -24,7 +24,6 @@ import textwrap
from functools import lru_cache, wraps from functools import lru_cache, wraps
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
from huggingface_hub import ( from huggingface_hub import (
create_repo, create_repo,
get_collection, get_collection,
@ -34,7 +33,7 @@ from huggingface_hub import (
) )
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from packaging import version from packaging import version
import logging
from transformers.utils import ( from transformers.utils import (
TypeHintParsingException, TypeHintParsingException,
cached_file, cached_file,
@ -43,11 +42,11 @@ from transformers.utils import (
is_torch_available, is_torch_available,
) )
from transformers.dynamic_module_utils import get_imports from transformers.dynamic_module_utils import get_imports
from .types import ImageType, handle_agent_input_types, handle_agent_output_types from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source from .utils import instance_to_source
from .tool_validation import validate_tool_attributes, MethodChecker from .tool_validation import validate_tool_attributes, MethodChecker
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -395,6 +394,9 @@ class Tool:
token (`str`, *optional*): token (`str`, *optional*):
The token to identify you on hf.co. If unset, will use the token generated when running The token to identify you on hf.co. If unset, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`). `huggingface-cli login` (stored in `~/.huggingface`).
trust_remote_code(`str`, *optional*, defaults to False):
This flags marks that you understand the risk of running remote code and that you trust this tool.
If not setting this to True, loading the tool from Hub will fail.
kwargs (additional keyword arguments, *optional*): kwargs (additional keyword arguments, *optional*):
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
@ -802,12 +804,6 @@ def load_tool(
tool_class = getattr(tools_module, tool_class_name) tool_class = getattr(tools_module, tool_class_name)
return tool_class(model_repo_id, token=token, **kwargs) return tool_class(model_repo_id, token=token, **kwargs)
else: else:
logger.warning_once(
f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
f"trust as the code within that tool will be executed on your machine. Always verify the code of "
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
f"code that you have checked."
)
return Tool.from_hub( return Tool.from_hub(
task_or_repo_id, task_or_repo_id,
model_repo_id=model_repo_id, model_repo_id=model_repo_id,