Nicer console outputs with rich
This commit is contained in:
		
							parent
							
								
									f3dcf1f013
								
							
						
					
					
						commit
						146ee3dd32
					
				|  | @ -16,7 +16,7 @@ | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||||
| 
 | 
 | ||||||
| from ..utils import ( | from transformers.utils import ( | ||||||
|     OptionalDependencyNotAvailable, |     OptionalDependencyNotAvailable, | ||||||
|     _LazyModule, |     _LazyModule, | ||||||
|     is_torch_available, |     is_torch_available, | ||||||
|  |  | ||||||
|  | @ -19,10 +19,11 @@ import uuid | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging | from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available | ||||||
|  | import logging | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| logger = logging.get_logger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| if is_vision_available(): | if is_vision_available(): | ||||||
|     from PIL import Image |     from PIL import Image | ||||||
|  | @ -159,7 +160,7 @@ class AgentImage(AgentType, ImageType): | ||||||
| 
 | 
 | ||||||
|             return self._path |             return self._path | ||||||
| 
 | 
 | ||||||
|     def save(self, output_bytes, format, **params): |     def save(self, output_bytes, format = None, **params): | ||||||
|         """ |         """ | ||||||
|         Saves the image to a file. |         Saves the image to a file. | ||||||
|         Args: |         Args: | ||||||
|  | @ -168,7 +169,7 @@ class AgentImage(AgentType, ImageType): | ||||||
|             **params: Additional parameters to pass to PIL.Image.save. |             **params: Additional parameters to pass to PIL.Image.save. | ||||||
|         """ |         """ | ||||||
|         img = self.to_raw() |         img = self.to_raw() | ||||||
|         img.save(output_bytes, format, **params) |         img.save(output_bytes, format=format, **params) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentAudio(AgentType, str): | class AgentAudio(AgentType, str): | ||||||
|  |  | ||||||
							
								
								
									
										233
									
								
								agents/agents.py
								
								
								
								
							
							
						
						
									
										233
									
								
								agents/agents.py
								
								
								
								
							|  | @ -19,10 +19,12 @@ import logging | ||||||
| import re | import re | ||||||
| import time | import time | ||||||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||||||
|  | import rich | ||||||
|  | from rich import markdown as rich_markdown | ||||||
| 
 | 
 | ||||||
| from .. import is_torch_available | from transformers.utils import is_torch_available | ||||||
| from ..utils import logging as transformers_logging | import logging | ||||||
| from ..utils.import_utils import is_pygments_available | from .utils import console | ||||||
| from .agent_types import AgentAudio, AgentImage | from .agent_types import AgentAudio, AgentImage | ||||||
| from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools | from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools | ||||||
| from .llm_engine import HfApiEngine, MessageRole | from .llm_engine import HfApiEngine, MessageRole | ||||||
|  | @ -48,52 +50,6 @@ from .tools import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if is_pygments_available(): |  | ||||||
|     from pygments import highlight |  | ||||||
|     from pygments.formatters import Terminal256Formatter |  | ||||||
|     from pygments.lexers import PythonLexer |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class CustomFormatter(logging.Formatter): |  | ||||||
|     grey = "\x1b[38;20m" |  | ||||||
|     bold_yellow = "\x1b[33;1m" |  | ||||||
|     red = "\x1b[31;20m" |  | ||||||
|     green = "\x1b[32;20m" |  | ||||||
|     bold_green = "\x1b[32;20;1m" |  | ||||||
|     bold_red = "\x1b[31;1m" |  | ||||||
|     bold_white = "\x1b[37;1m" |  | ||||||
|     orange = "\x1b[38;5;214m" |  | ||||||
|     bold_orange = "\x1b[38;5;214;1m" |  | ||||||
|     reset = "\x1b[0m" |  | ||||||
|     format = "%(message)s" |  | ||||||
| 
 |  | ||||||
|     FORMATS = { |  | ||||||
|         logging.DEBUG: grey + format + reset, |  | ||||||
|         logging.INFO: format, |  | ||||||
|         logging.WARNING: bold_yellow + format + reset, |  | ||||||
|         logging.ERROR: red + format + reset, |  | ||||||
|         logging.CRITICAL: bold_red + format + reset, |  | ||||||
|         31: reset + format + reset, |  | ||||||
|         32: green + format + reset, |  | ||||||
|         33: bold_green + format + reset, |  | ||||||
|         34: bold_white + format + reset, |  | ||||||
|         35: orange + format + reset, |  | ||||||
|         36: bold_orange + format + reset, |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     def format(self, record): |  | ||||||
|         log_fmt = self.FORMATS.get(record.levelno) |  | ||||||
|         formatter = logging.Formatter(log_fmt) |  | ||||||
|         return formatter.format(record) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| logger = transformers_logging.get_logger(__name__) |  | ||||||
| logger.propagate = False |  | ||||||
| ch = logging.StreamHandler() |  | ||||||
| ch.setFormatter(CustomFormatter()) |  | ||||||
| logger.addHandler(ch) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def parse_json_blob(json_blob: str) -> Dict[str, str]: | def parse_json_blob(json_blob: str) -> Dict[str, str]: | ||||||
|     try: |     try: | ||||||
|         first_accolade_index = json_blob.find("{") |         first_accolade_index = json_blob.find("{") | ||||||
|  | @ -142,9 +98,10 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]: | ||||||
|     elif "action" in tool_call: |     elif "action" in tool_call: | ||||||
|         return tool_call["action"], None |         return tool_call["action"], None | ||||||
|     else: |     else: | ||||||
|         raise ValueError( |         missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call] | ||||||
|             f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}" |         error_msg = f"Missing keys: {missing_keys} in blob {tool_call}" | ||||||
|         ) |         console.print(f"[bold red]{error_msg}[/bold red]") | ||||||
|  |         raise ValueError(error_msg) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]: | def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]: | ||||||
|  | @ -197,18 +154,18 @@ class Toolbox: | ||||||
|         self._tools = {tool.name: tool for tool in tools} |         self._tools = {tool.name: tool for tool in tools} | ||||||
|         if add_base_tools: |         if add_base_tools: | ||||||
|             self.add_base_tools() |             self.add_base_tools() | ||||||
|         self._load_tools_if_needed() |         # self._load_tools_if_needed() | ||||||
| 
 | 
 | ||||||
|     def add_base_tools(self, add_python_interpreter: bool = False): |     def add_base_tools(self, add_python_interpreter: bool = False): | ||||||
|         global _tools_are_initialized |         global _tools_are_initialized | ||||||
|         global HUGGINGFACE_DEFAULT_TOOLS |         global HUGGINGFACE_DEFAULT_TOOLS | ||||||
|         if not _tools_are_initialized: |         if not _tools_are_initialized: | ||||||
|             HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger) |             HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools() | ||||||
|             _tools_are_initialized = True |             _tools_are_initialized = True | ||||||
|         for tool in HUGGINGFACE_DEFAULT_TOOLS.values(): |         for tool in HUGGINGFACE_DEFAULT_TOOLS.values(): | ||||||
|             if tool.name != "python_interpreter" or add_python_interpreter: |             if tool.name != "python_interpreter" or add_python_interpreter: | ||||||
|                 self.add_tool(tool) |                 self.add_tool(tool) | ||||||
|         self._load_tools_if_needed() |         # self._load_tools_if_needed() | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def tools(self) -> Dict[str, Tool]: |     def tools(self) -> Dict[str, Tool]: | ||||||
|  | @ -271,11 +228,11 @@ class Toolbox: | ||||||
|         """Clears the toolbox""" |         """Clears the toolbox""" | ||||||
|         self._tools = {} |         self._tools = {} | ||||||
| 
 | 
 | ||||||
|     def _load_tools_if_needed(self): |     # def _load_tools_if_needed(self): | ||||||
|         for name, tool in self._tools.items(): |     #     for name, tool in self._tools.items(): | ||||||
|             if not isinstance(tool, Tool): |     #         if not isinstance(tool, Tool): | ||||||
|                 task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id |     #             task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id | ||||||
|                 self._tools[name] = load_tool(task_or_repo_id) |     #             self._tools[name] = load_tool(task_or_repo_id) | ||||||
| 
 | 
 | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         toolbox_description = "Toolbox contents:\n" |         toolbox_description = "Toolbox contents:\n" | ||||||
|  | @ -290,6 +247,8 @@ class AgentError(Exception): | ||||||
|     def __init__(self, message): |     def __init__(self, message): | ||||||
|         super().__init__(message) |         super().__init__(message) | ||||||
|         self.message = message |         self.message = message | ||||||
|  |         console.print(f"[bold red]{message}[/bold red]") | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentParsingError(AgentError): | class AgentParsingError(AgentError): | ||||||
|  | @ -362,7 +321,7 @@ class Agent: | ||||||
|         max_iterations: int = 6, |         max_iterations: int = 6, | ||||||
|         tool_parser: Optional[Callable] = None, |         tool_parser: Optional[Callable] = None, | ||||||
|         add_base_tools: bool = False, |         add_base_tools: bool = False, | ||||||
|         verbose: int = 0, |         verbose: bool = False, | ||||||
|         grammar: Optional[Dict[str, str]] = None, |         grammar: Optional[Dict[str, str]] = None, | ||||||
|         managed_agents: Optional[List] = None, |         managed_agents: Optional[List] = None, | ||||||
|         step_callbacks: Optional[List[Callable]] = None, |         step_callbacks: Optional[List[Callable]] = None, | ||||||
|  | @ -380,7 +339,6 @@ class Agent: | ||||||
|         ) |         ) | ||||||
|         self.additional_args = additional_args |         self.additional_args = additional_args | ||||||
|         self.max_iterations = max_iterations |         self.max_iterations = max_iterations | ||||||
|         self.logger = logger |  | ||||||
|         self.tool_parser = tool_parser |         self.tool_parser = tool_parser | ||||||
|         self.grammar = grammar |         self.grammar = grammar | ||||||
| 
 | 
 | ||||||
|  | @ -406,13 +364,7 @@ class Agent: | ||||||
|         self.prompt = None |         self.prompt = None | ||||||
|         self.logs = [] |         self.logs = [] | ||||||
|         self.task = None |         self.task = None | ||||||
| 
 |         self.verbose = verbose | ||||||
|         if verbose == 0: |  | ||||||
|             logger.setLevel(logging.WARNING) |  | ||||||
|         elif verbose == 1: |  | ||||||
|             logger.setLevel(logging.INFO) |  | ||||||
|         elif verbose == 2: |  | ||||||
|             logger.setLevel(logging.DEBUG) |  | ||||||
| 
 | 
 | ||||||
|         # Initialize step callbacks |         # Initialize step callbacks | ||||||
|         self.step_callbacks = step_callbacks if step_callbacks is not None else [] |         self.step_callbacks = step_callbacks if step_callbacks is not None else [] | ||||||
|  | @ -441,10 +393,8 @@ class Agent: | ||||||
|                 self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) |                 self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) | ||||||
|             ) |             ) | ||||||
|         self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] |         self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] | ||||||
|         self.logger.log(33, "======== New task ========") |         console.rule("New task", characters='=') | ||||||
|         self.logger.log(34, self.task) |         console.print(self.task) | ||||||
|         self.logger.debug("System prompt is as follows:") |  | ||||||
|         self.logger.debug(self.system_prompt) |  | ||||||
| 
 | 
 | ||||||
|     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: |     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: | ||||||
|         """ |         """ | ||||||
|  | @ -521,7 +471,6 @@ class Agent: | ||||||
|                 split[-1], |                 split[-1], | ||||||
|             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output |             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             self.logger.error(e, exc_info=1) |  | ||||||
|             raise AgentParsingError( |             raise AgentParsingError( | ||||||
|                 f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" |                 f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" | ||||||
|             ) |             ) | ||||||
|  | @ -541,7 +490,7 @@ class Agent: | ||||||
|             available_tools = {**available_tools, **self.managed_agents} |             available_tools = {**available_tools, **self.managed_agents} | ||||||
|         if tool_name not in available_tools: |         if tool_name not in available_tools: | ||||||
|             error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." |             error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." | ||||||
|             self.logger.error(error_msg, exc_info=1) |             console.print(f"[bold red]{error_msg}") | ||||||
|             raise AgentExecutionError(error_msg) |             raise AgentExecutionError(error_msg) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|  | @ -549,38 +498,39 @@ class Agent: | ||||||
|                 observation = available_tools[tool_name](arguments) |                 observation = available_tools[tool_name](arguments) | ||||||
|             elif isinstance(arguments, dict): |             elif isinstance(arguments, dict): | ||||||
|                 for key, value in arguments.items(): |                 for key, value in arguments.items(): | ||||||
|                     # if the value is the name of a state variable like "image.png", replace it with the actual value |  | ||||||
|                     if isinstance(value, str) and value in self.state: |                     if isinstance(value, str) and value in self.state: | ||||||
|                         arguments[key] = self.state[value] |                         arguments[key] = self.state[value] | ||||||
|                 observation = available_tools[tool_name](**arguments) |                 observation = available_tools[tool_name](**arguments) | ||||||
|             else: |             else: | ||||||
|                 raise AgentExecutionError( |                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." | ||||||
|                     f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." |                 console.print(f"[bold red]{error_msg}") | ||||||
|                 ) |                 raise AgentExecutionError(error_msg) | ||||||
|             return observation |             return observation | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             if tool_name in self.toolbox.tools: |             if tool_name in self.toolbox.tools: | ||||||
|                 raise AgentExecutionError( |                 tool_description = get_tool_description_with_args(available_tools[tool_name]) | ||||||
|  |                 error_msg = ( | ||||||
|                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" |                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" | ||||||
|                     f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}" |                     f"As a reminder, this tool's description is the following:\n{tool_description}" | ||||||
|                 ) |                 ) | ||||||
|  |                 console.print(f"[bold red]{error_msg}") | ||||||
|  |                 raise AgentExecutionError(error_msg) | ||||||
|             elif tool_name in self.managed_agents: |             elif tool_name in self.managed_agents: | ||||||
|                 raise AgentExecutionError( |                 error_msg = ( | ||||||
|                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" |                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" | ||||||
|                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" |                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" | ||||||
|                 ) |                 ) | ||||||
|  |                 console.print(f"[bold red]{error_msg}") | ||||||
|  |                 raise AgentExecutionError(error_msg) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
|     def log_rationale_code_action(self, rationale: str, code_action: str) -> None: |     def log_rationale_code_action(self, rationale: str, code_action: str) -> None: | ||||||
|         self.logger.warning("=== Agent thoughts:") |         if self.verbose: | ||||||
|         self.logger.log(31, rationale) |             console.rule("Agent thoughts") | ||||||
|         self.logger.warning(">>> Agent is executing the code below:") |             console.print(rationale) | ||||||
|         if is_pygments_available(): |         console.rule("Agent is executing the code below:", align="left") | ||||||
|             self.logger.log( |         console.print(code_action) | ||||||
|                 31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord")) |         console.rule("", align="left") | ||||||
|             ) |  | ||||||
|         else: |  | ||||||
|             self.logger.log(31, code_action) |  | ||||||
|         self.logger.warning("====") |  | ||||||
| 
 | 
 | ||||||
|     def run(self, **kwargs): |     def run(self, **kwargs): | ||||||
|         """To be implemented in the child class""" |         """To be implemented in the child class""" | ||||||
|  | @ -617,13 +567,6 @@ class CodeAgent(Agent): | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if not is_pygments_available(): |  | ||||||
|             transformers_logging.warning_once( |  | ||||||
|                 logger, |  | ||||||
|                 "pygments isn't installed. Installing pygments will enable color syntax highlighting in the " |  | ||||||
|                 "CodeAgent.", |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         self.python_evaluator = evaluate_python_code |         self.python_evaluator = evaluate_python_code | ||||||
|         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] |         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] | ||||||
|         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) |         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) | ||||||
|  | @ -669,20 +612,21 @@ class CodeAgent(Agent): | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         self.prompt = [prompt_message, task_message] |         self.prompt = [prompt_message, task_message] | ||||||
|         self.logger.info("====Executing with this prompt====") | 
 | ||||||
|         self.logger.info(self.prompt) |         if self.verbose: | ||||||
|  |             console.rule("Executing with this prompt") | ||||||
|  |             console.print(self.prompt) | ||||||
| 
 | 
 | ||||||
|         additional_args = {"grammar": self.grammar} if self.grammar is not None else {} |         additional_args = {"grammar": self.grammar} if self.grammar is not None else {} | ||||||
|         llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args) |         llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args) | ||||||
| 
 | 
 | ||||||
|         if return_generated_code: |  | ||||||
|             return llm_output |  | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         try: |         try: | ||||||
|             rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") |             rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             self.logger.debug( |             if self.verbose: | ||||||
|  |                 console.print( | ||||||
|                     f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}" |                     f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}" | ||||||
|                 ) |                 ) | ||||||
|             rationale, code_action = "", llm_output |             rationale, code_action = "", llm_output | ||||||
|  | @ -691,7 +635,7 @@ class CodeAgent(Agent): | ||||||
|             code_action = self.parse_code_blob(code_action) |             code_action = self.parse_code_blob(code_action) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = f"Error in code parsing: {e}. Be sure to provide correct code" |             error_msg = f"Error in code parsing: {e}. Be sure to provide correct code" | ||||||
|             self.logger.error(error_msg, exc_info=1) |             console.print(f"[bold red]{error_msg}[/bold red]") | ||||||
|             return error_msg |             return error_msg | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|  | @ -705,11 +649,12 @@ class CodeAgent(Agent): | ||||||
|                 state=self.state, |                 state=self.state, | ||||||
|                 authorized_imports=self.authorized_imports, |                 authorized_imports=self.authorized_imports, | ||||||
|             ) |             ) | ||||||
|             self.logger.info(self.state["print_outputs"]) |             if self.verbose: | ||||||
|  |                 console.print(self.state["print_outputs"]) | ||||||
|             return output |             return output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = f"Error in execution: {e}. Be sure to provide correct code." |             error_msg = f"Error in execution: {e}. Be sure to provide correct code." | ||||||
|             self.logger.error(error_msg, exc_info=1) |             console.print(f"[bold red]{error_msg}[/bold red]") | ||||||
|             return error_msg |             return error_msg | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -759,7 +704,7 @@ class ReactAgent(Agent): | ||||||
|         self.prompt = [ |         self.prompt = [ | ||||||
|             { |             { | ||||||
|                 "role": MessageRole.SYSTEM, |                 "role": MessageRole.SYSTEM, | ||||||
|                 "content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", |                 "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", | ||||||
|             } |             } | ||||||
|         ] |         ] | ||||||
|         self.prompt += self.write_inner_memory_from_logs()[1:] |         self.prompt += self.write_inner_memory_from_logs()[1:] | ||||||
|  | @ -772,7 +717,9 @@ class ReactAgent(Agent): | ||||||
|         try: |         try: | ||||||
|             return self.llm_engine(self.prompt) |             return self.llm_engine(self.prompt) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             return f"Error in generating final llm output: {e}." |             error_msg = f"Error in generating final LLM output: {e}." | ||||||
|  |             console.print(f"[bold red]{error_msg}[/bold red]") | ||||||
|  |             return error_msg | ||||||
| 
 | 
 | ||||||
|     def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs): |     def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs): | ||||||
|         """ |         """ | ||||||
|  | @ -815,7 +762,6 @@ class ReactAgent(Agent): | ||||||
|                 if "final_answer" in step_log_entry: |                 if "final_answer" in step_log_entry: | ||||||
|                     final_answer = step_log_entry["final_answer"] |                     final_answer = step_log_entry["final_answer"] | ||||||
|             except AgentError as e: |             except AgentError as e: | ||||||
|                 self.logger.error(e, exc_info=1) |  | ||||||
|                 step_log_entry["error"] = e |                 step_log_entry["error"] = e | ||||||
|             finally: |             finally: | ||||||
|                 step_end_time = time.time() |                 step_end_time = time.time() | ||||||
|  | @ -831,7 +777,7 @@ class ReactAgent(Agent): | ||||||
|             error_message = "Reached max iterations." |             error_message = "Reached max iterations." | ||||||
|             final_step_log = {"error": AgentMaxIterationsError(error_message)} |             final_step_log = {"error": AgentMaxIterationsError(error_message)} | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             self.logger.error(error_message, exc_info=1) |             console.print(f"[bold red]{error_message}") | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             final_step_log["final_answer"] = final_answer |             final_step_log["final_answer"] = final_answer | ||||||
|             final_step_log["step_duration"] = 0 |             final_step_log["step_duration"] = 0 | ||||||
|  | @ -857,7 +803,6 @@ class ReactAgent(Agent): | ||||||
|                 if "final_answer" in step_log_entry: |                 if "final_answer" in step_log_entry: | ||||||
|                     final_answer = step_log_entry["final_answer"] |                     final_answer = step_log_entry["final_answer"] | ||||||
|             except AgentError as e: |             except AgentError as e: | ||||||
|                 self.logger.error(e, exc_info=1) |  | ||||||
|                 step_log_entry["error"] = e |                 step_log_entry["error"] = e | ||||||
|             finally: |             finally: | ||||||
|                 step_end_time = time.time() |                 step_end_time = time.time() | ||||||
|  | @ -872,7 +817,7 @@ class ReactAgent(Agent): | ||||||
|             error_message = "Reached max iterations." |             error_message = "Reached max iterations." | ||||||
|             final_step_log = {"error": AgentMaxIterationsError(error_message)} |             final_step_log = {"error": AgentMaxIterationsError(error_message)} | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             self.logger.error(error_message, exc_info=1) |             console.print(f"[bold red]{error_message}") | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             final_step_log["final_answer"] = final_answer |             final_step_log["final_answer"] = final_answer | ||||||
|             final_step_log["step_duration"] = 0 |             final_step_log["step_duration"] = 0 | ||||||
|  | @ -931,8 +876,8 @@ Now begin!""", | ||||||
| {answer_facts} | {answer_facts} | ||||||
| ```""".strip() | ```""".strip() | ||||||
|             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) |             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) | ||||||
|             self.logger.log(36, "===== Initial plan =====") |             console.rule("[orange]Initial plan") | ||||||
|             self.logger.log(35, final_plan_redaction) |             console.print(final_plan_redaction) | ||||||
|         else:  # update plan |         else:  # update plan | ||||||
|             agent_memory = self.write_inner_memory_from_logs( |             agent_memory = self.write_inner_memory_from_logs( | ||||||
|                 summary_mode=False |                 summary_mode=False | ||||||
|  | @ -977,8 +922,8 @@ Now begin!""", | ||||||
| {facts_update} | {facts_update} | ||||||
| ```""" | ```""" | ||||||
|             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) |             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) | ||||||
|             self.logger.log(36, "===== Updated plan =====") |             console.rule("Updated plan") | ||||||
|             self.logger.log(35, final_plan_redaction) |             console.print(final_plan_redaction) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ReactJsonAgent(ReactAgent): | class ReactJsonAgent(ReactAgent): | ||||||
|  | @ -1022,13 +967,14 @@ class ReactJsonAgent(ReactAgent): | ||||||
|         agent_memory = self.write_inner_memory_from_logs() |         agent_memory = self.write_inner_memory_from_logs() | ||||||
| 
 | 
 | ||||||
|         self.prompt = agent_memory |         self.prompt = agent_memory | ||||||
|         self.logger.debug("===== New step =====") |         console.rule("New step") | ||||||
| 
 | 
 | ||||||
|         # Add new step in logs |         # Add new step in logs | ||||||
|         log_entry["agent_memory"] = agent_memory.copy() |         log_entry["agent_memory"] = agent_memory.copy() | ||||||
| 
 | 
 | ||||||
|         self.logger.info("===== Calling LLM with this last message: =====") |         if self.verbose: | ||||||
|         self.logger.info(self.prompt[-1]) |             console.rule("Calling LLM with this last message:") | ||||||
|  |             console.print(self.prompt[-1]) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} |             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} | ||||||
|  | @ -1037,12 +983,12 @@ class ReactJsonAgent(ReactAgent): | ||||||
|             ) |             ) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError(f"Error in generating llm output: {e}.") |             raise AgentGenerationError(f"Error in generating llm output: {e}.") | ||||||
|         self.logger.debug("===== Output message of the LLM: =====") |         console.rule("===== Output message of the LLM: =====") | ||||||
|         self.logger.debug(llm_output) |         console.print(llm_output) | ||||||
|         log_entry["llm_output"] = llm_output |         log_entry["llm_output"] = llm_output | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         self.logger.debug("===== Extracting action =====") |         console.rule("===== Extracting action =====") | ||||||
|         rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") |         rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|  | @ -1054,9 +1000,9 @@ class ReactJsonAgent(ReactAgent): | ||||||
|         log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} |         log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         self.logger.warning("=== Agent thoughts:") |         console.print("=== Agent thoughts:") | ||||||
|         self.logger.log(31, rationale) |         console.print(rationale) | ||||||
|         self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") |         console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") | ||||||
|         if tool_name == "final_answer": |         if tool_name == "final_answer": | ||||||
|             if isinstance(arguments, dict): |             if isinstance(arguments, dict): | ||||||
|                 if "answer" in arguments: |                 if "answer" in arguments: | ||||||
|  | @ -1087,7 +1033,6 @@ class ReactJsonAgent(ReactAgent): | ||||||
|                 updated_information = f"Stored '{observation_name}' in memory." |                 updated_information = f"Stored '{observation_name}' in memory." | ||||||
|             else: |             else: | ||||||
|                 updated_information = str(observation).strip() |                 updated_information = str(observation).strip() | ||||||
|             self.logger.info(updated_information) |  | ||||||
|             log_entry["observation"] = updated_information |             log_entry["observation"] = updated_information | ||||||
|             return log_entry |             return log_entry | ||||||
| 
 | 
 | ||||||
|  | @ -1126,13 +1071,6 @@ class ReactCodeAgent(ReactAgent): | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if not is_pygments_available(): |  | ||||||
|             transformers_logging.warning_once( |  | ||||||
|                 logger, |  | ||||||
|                 "pygments isn't installed. Installing pygments will enable color syntax highlighting in the " |  | ||||||
|                 "ReactCodeAgent.", |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         self.python_evaluator = evaluate_python_code |         self.python_evaluator = evaluate_python_code | ||||||
|         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] |         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] | ||||||
|         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) |         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) | ||||||
|  | @ -1147,13 +1085,14 @@ class ReactCodeAgent(ReactAgent): | ||||||
|         agent_memory = self.write_inner_memory_from_logs() |         agent_memory = self.write_inner_memory_from_logs() | ||||||
| 
 | 
 | ||||||
|         self.prompt = agent_memory.copy() |         self.prompt = agent_memory.copy() | ||||||
|         self.logger.debug("===== New step =====") |         console.rule("New step") | ||||||
| 
 | 
 | ||||||
|         # Add new step in logs |         # Add new step in logs | ||||||
|         log_entry["agent_memory"] = agent_memory.copy() |         log_entry["agent_memory"] = agent_memory.copy() | ||||||
| 
 | 
 | ||||||
|         self.logger.info("===== Calling LLM with these last messages: =====") |         if self.verbose: | ||||||
|         self.logger.info(self.prompt[-2:]) |             console.print("===== Calling LLM with these last messages: =====") | ||||||
|  |             console.print(self.prompt[-2:]) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} |             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} | ||||||
|  | @ -1163,16 +1102,16 @@ class ReactCodeAgent(ReactAgent): | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError(f"Error in generating llm output: {e}.") |             raise AgentGenerationError(f"Error in generating llm output: {e}.") | ||||||
| 
 | 
 | ||||||
|         self.logger.debug("=== Output message of the LLM:") |         if self.verbose: | ||||||
|         self.logger.debug(llm_output) |             console.rule("Output message of the LLM:") | ||||||
|  |             console.print(llm_output) | ||||||
|         log_entry["llm_output"] = llm_output |         log_entry["llm_output"] = llm_output | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         self.logger.debug("=== Extracting action ===") |  | ||||||
|         try: |         try: | ||||||
|             rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") |             rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") |             console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") | ||||||
|             rationale, raw_code_action = llm_output, llm_output |             rationale, raw_code_action = llm_output, llm_output | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|  | @ -1200,12 +1139,12 @@ class ReactCodeAgent(ReactAgent): | ||||||
|                 state=self.state, |                 state=self.state, | ||||||
|                 authorized_imports=self.authorized_imports, |                 authorized_imports=self.authorized_imports, | ||||||
|             ) |             ) | ||||||
|             self.logger.warning("Print outputs:") |             console.print("Print outputs:") | ||||||
|             self.logger.log(32, self.state["print_outputs"]) |             console.print(self.state["print_outputs"]) | ||||||
|             observation = "Print outputs:\n" + self.state["print_outputs"] |             observation = "Print outputs:\n" + self.state["print_outputs"] | ||||||
|             if result is not None: |             if result is not None: | ||||||
|                 self.logger.warning("Last output from code snippet:") |                 console.print("Last output from code snippet:") | ||||||
|                 self.logger.log(32, str(result)) |                 console.print(str(result)) | ||||||
|                 observation += "Last output from code snippet:\n" + str(result)[:100000] |                 observation += "Last output from code snippet:\n" + str(result)[:100000] | ||||||
|             log_entry["observation"] = observation |             log_entry["observation"] = observation | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|  | @ -1215,8 +1154,8 @@ class ReactCodeAgent(ReactAgent): | ||||||
|             raise AgentExecutionError(error_msg) |             raise AgentExecutionError(error_msg) | ||||||
|         for line in code_action.split("\n"): |         for line in code_action.split("\n"): | ||||||
|             if line[: len("final_answer")] == "final_answer": |             if line[: len("final_answer")] == "final_answer": | ||||||
|                 self.logger.log(33, "Final answer:") |                 console.print("Final answer:") | ||||||
|                 self.logger.log(32, result) |                 console.print(f"[bold]{result}") | ||||||
|                 log_entry["final_answer"] = result |                 log_entry["final_answer"] = result | ||||||
|         return result |         return result | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -23,7 +23,7 @@ from typing import Dict | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import hf_hub_download, list_spaces | from huggingface_hub import hf_hub_download, list_spaces | ||||||
| 
 | 
 | ||||||
| from ..utils import is_offline_mode | from transformers.utils import is_offline_mode | ||||||
| from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code | from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code | ||||||
| from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool | from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool | ||||||
| 
 | 
 | ||||||
|  | @ -128,7 +128,7 @@ def get_remote_tools(logger, organization="huggingface-tools"): | ||||||
|     return tools |     return tools | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def setup_default_tools(logger): | def setup_default_tools(): | ||||||
|     default_tools = {} |     default_tools = {} | ||||||
|     main_module = importlib.import_module("transformers") |     main_module = importlib.import_module("transformers") | ||||||
|     tools_module = main_module.agents |     tools_module = main_module.agents | ||||||
|  |  | ||||||
|  | @ -19,9 +19,8 @@ import re | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| from ..models.auto import AutoProcessor | from transformers import AutoProcessor, VisionEncoderDecoderModel | ||||||
| from ..models.vision_encoder_decoder import VisionEncoderDecoderModel | from transformers.utils import is_vision_available | ||||||
| from ..utils import is_vision_available |  | ||||||
| from .tools import PipelineTool | from .tools import PipelineTool | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -18,8 +18,8 @@ | ||||||
| import torch | import torch | ||||||
| from PIL import Image | from PIL import Image | ||||||
| 
 | 
 | ||||||
| from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor | from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor | ||||||
| from ..utils import requires_backends | from transformers.utils import requires_backends | ||||||
| from .tools import PipelineTool | from .tools import PipelineTool | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -20,12 +20,11 @@ from typing import Dict, List, Optional | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import InferenceClient | from huggingface_hub import InferenceClient | ||||||
| 
 | 
 | ||||||
| from .. import AutoTokenizer | from transformers import AutoTokenizer, Pipeline | ||||||
| from ..pipelines.base import Pipeline | import logging | ||||||
| from ..utils import logging |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| logger = logging.get_logger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class MessageRole(str, Enum): | class MessageRole(str, Enum): | ||||||
|  |  | ||||||
|  | @ -14,11 +14,8 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from ..utils import logging |  | ||||||
| from .agent_types import AgentAudio, AgentImage, AgentText | from .agent_types import AgentAudio, AgentImage, AgentText | ||||||
| 
 | from .utils import console | ||||||
| 
 |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def pull_message(step_log: dict, test_mode: bool = True): | def pull_message(step_log: dict, test_mode: bool = True): | ||||||
|  | @ -107,11 +104,12 @@ class Monitor: | ||||||
|     def update_metrics(self, step_log): |     def update_metrics(self, step_log): | ||||||
|         step_duration = step_log["step_duration"] |         step_duration = step_log["step_duration"] | ||||||
|         self.step_durations.append(step_duration) |         self.step_durations.append(step_duration) | ||||||
|         logger.info(f"Step {len(self.step_durations)}:") |         console.print(f"Step {len(self.step_durations)}:") | ||||||
|         logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)") |         console.print(f"- Time taken: {step_duration:.2f} seconds") | ||||||
| 
 | 
 | ||||||
|         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: |         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: | ||||||
|             self.total_input_token_count += self.tracked_llm_engine.last_input_token_count |             self.total_input_token_count += self.tracked_llm_engine.last_input_token_count | ||||||
|             self.total_output_token_count += self.tracked_llm_engine.last_output_token_count |             self.total_output_token_count += self.tracked_llm_engine.last_output_token_count | ||||||
|             logger.info(f"- Input tokens: {self.total_input_token_count}") |             console.print(f"- Input tokens: {self.total_input_token_count}") | ||||||
|             logger.info(f"- Output tokens: {self.total_output_token_count}") |             console.print(f"- Output tokens: {self.total_output_token_count}") | ||||||
|  | 
 | ||||||
|  |  | ||||||
|  | @ -16,7 +16,7 @@ | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import re | import re | ||||||
| 
 | 
 | ||||||
| from ..utils import cached_file | from transformers.utils import cached_file | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # docstyle-ignore | # docstyle-ignore | ||||||
|  |  | ||||||
|  | @ -22,12 +22,7 @@ from importlib import import_module | ||||||
| from typing import Any, Callable, Dict, List, Optional | from typing import Any, Callable, Dict, List, Optional | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | import pandas as pd | ||||||
| from ..utils import is_pandas_available |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| if is_pandas_available(): |  | ||||||
|     import pandas as pd |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class InterpreterError(ValueError): | class InterpreterError(ValueError): | ||||||
|  |  | ||||||
|  | @ -15,7 +15,7 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
| from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor | from transformers import WhisperForConditionalGeneration, WhisperProcessor | ||||||
| from .tools import PipelineTool | from .tools import PipelineTool | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -17,8 +17,8 @@ | ||||||
| 
 | 
 | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor | from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor | ||||||
| from ..utils import is_datasets_available | from transformers.utils import is_datasets_available | ||||||
| from .tools import PipelineTool | from .tools import PipelineTool | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -30,13 +30,13 @@ from huggingface_hub import create_repo, get_collection, hf_hub_download, metada | ||||||
| from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session | from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session | ||||||
| from packaging import version | from packaging import version | ||||||
| 
 | 
 | ||||||
| from ..dynamic_module_utils import ( | from transformers.dynamic_module_utils import ( | ||||||
|     custom_object_save, |     custom_object_save, | ||||||
|     get_class_from_dynamic_module, |     get_class_from_dynamic_module, | ||||||
|     get_imports, |     get_imports, | ||||||
| ) | ) | ||||||
| from ..models.auto import AutoProcessor | from transformers import AutoProcessor | ||||||
| from ..utils import ( | from transformers.utils import ( | ||||||
|     CONFIG_NAME, |     CONFIG_NAME, | ||||||
|     TypeHintParsingException, |     TypeHintParsingException, | ||||||
|     cached_file, |     cached_file, | ||||||
|  | @ -44,12 +44,11 @@ from ..utils import ( | ||||||
|     is_accelerate_available, |     is_accelerate_available, | ||||||
|     is_torch_available, |     is_torch_available, | ||||||
|     is_vision_available, |     is_vision_available, | ||||||
|     logging, |  | ||||||
| ) | ) | ||||||
| from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs | from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs | ||||||
|  | import logging | ||||||
| 
 | 
 | ||||||
| 
 | logger = logging.getLogger(__name__) | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if is_torch_available(): | if is_torch_available(): | ||||||
|  |  | ||||||
|  | @ -14,7 +14,7 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | ||||||
| from .tools import PipelineTool | from .tools import PipelineTool | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,10 @@ | ||||||
|  | 
 | ||||||
|  | from transformers.utils.import_utils import _is_package_available | ||||||
|  | 
 | ||||||
|  | _pygments_available = _is_package_available("pygments") | ||||||
|  | 
 | ||||||
|  | def is_pygments_available(): | ||||||
|  |     return _pygments_available | ||||||
|  | 
 | ||||||
|  | from rich.console import Console | ||||||
|  | console = Console() | ||||||
|  | @ -47,10 +47,10 @@ TRANSFORMERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py | ||||||
| Here is an example of how to use the same logger as the library in your own module or script: | Here is an example of how to use the same logger as the library in your own module or script: | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| from transformers.utils import logging | import logging | ||||||
| 
 | 
 | ||||||
| logging.set_verbosity_info() | logging.set_verbosity_info() | ||||||
| logger = logging.get_logger("transformers") | logger = logging.getLogger(__name__)("transformers") | ||||||
| logger.info("INFO") | logger.info("INFO") | ||||||
| logger.warning("WARN") | logger.warning("WARN") | ||||||
| ``` | ``` | ||||||
|  | @ -104,7 +104,7 @@ See reference of the `captureWarnings` method below. | ||||||
| 
 | 
 | ||||||
| [[autodoc]] logging.set_verbosity | [[autodoc]] logging.set_verbosity | ||||||
| 
 | 
 | ||||||
| [[autodoc]] logging.get_logger | [[autodoc]] logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| [[autodoc]] logging.enable_default_handler | [[autodoc]] logging.enable_default_handler | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,20 @@ | ||||||
|  | from agents import load_tool, ReactCodeAgent, HfApiEngine | ||||||
|  | 
 | ||||||
|  | # Import tool from Hub | ||||||
|  | image_generation_tool = load_tool("m-ric/text-to-image", cache=False) | ||||||
|  | 
 | ||||||
|  | # Import tool from LangChain | ||||||
|  | from agents.search import DuckDuckGoSearchTool | ||||||
|  | 
 | ||||||
|  | search_tool = DuckDuckGoSearchTool() | ||||||
|  | 
 | ||||||
|  | llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct") | ||||||
|  | # Initialize the agent with both tools | ||||||
|  | agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine) | ||||||
|  | 
 | ||||||
|  | # Run it! | ||||||
|  | result = agent.run( | ||||||
|  |     "Return me a photo of the car that James bond drove in the latest movie.", | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | print(result) | ||||||
|  | @ -3,3 +3,4 @@ evaluate | ||||||
| datasets==2.3.2 | datasets==2.3.2 | ||||||
| schedulefree | schedulefree | ||||||
| huggingface_hub>=0.20.0 | huggingface_hub>=0.20.0 | ||||||
|  | duckduckgo-search | ||||||
|  | @ -264,6 +264,41 @@ files = [ | ||||||
|     {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, |     {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "markdown-it-py" | ||||||
|  | version = "3.0.0" | ||||||
|  | description = "Python port of markdown-it. Markdown parsing, done right!" | ||||||
|  | optional = false | ||||||
|  | python-versions = ">=3.8" | ||||||
|  | files = [ | ||||||
|  |     {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, | ||||||
|  |     {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.dependencies] | ||||||
|  | mdurl = ">=0.1,<1.0" | ||||||
|  | 
 | ||||||
|  | [package.extras] | ||||||
|  | benchmarking = ["psutil", "pytest", "pytest-benchmark"] | ||||||
|  | code-style = ["pre-commit (>=3.0,<4.0)"] | ||||||
|  | compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] | ||||||
|  | linkify = ["linkify-it-py (>=1,<3)"] | ||||||
|  | plugins = ["mdit-py-plugins"] | ||||||
|  | profiling = ["gprof2dot"] | ||||||
|  | rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] | ||||||
|  | testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "mdurl" | ||||||
|  | version = "0.1.2" | ||||||
|  | description = "Markdown URL utilities" | ||||||
|  | optional = false | ||||||
|  | python-versions = ">=3.7" | ||||||
|  | files = [ | ||||||
|  |     {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, | ||||||
|  |     {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "numpy" | name = "numpy" | ||||||
| version = "2.1.3" | version = "2.1.3" | ||||||
|  | @ -354,6 +389,20 @@ files = [ | ||||||
| dev = ["pre-commit", "tox"] | dev = ["pre-commit", "tox"] | ||||||
| testing = ["pytest", "pytest-benchmark"] | testing = ["pytest", "pytest-benchmark"] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "pygments" | ||||||
|  | version = "2.18.0" | ||||||
|  | description = "Pygments is a syntax highlighting package written in Python." | ||||||
|  | optional = false | ||||||
|  | python-versions = ">=3.8" | ||||||
|  | files = [ | ||||||
|  |     {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, | ||||||
|  |     {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.extras] | ||||||
|  | windows-terminal = ["colorama (>=0.4.6)"] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "pytest" | name = "pytest" | ||||||
| version = "8.3.4" | version = "8.3.4" | ||||||
|  | @ -562,6 +611,25 @@ urllib3 = ">=1.21.1,<3" | ||||||
| socks = ["PySocks (>=1.5.6,!=1.5.7)"] | socks = ["PySocks (>=1.5.6,!=1.5.7)"] | ||||||
| use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] | use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "rich" | ||||||
|  | version = "13.9.4" | ||||||
|  | description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" | ||||||
|  | optional = false | ||||||
|  | python-versions = ">=3.8.0" | ||||||
|  | files = [ | ||||||
|  |     {file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"}, | ||||||
|  |     {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.dependencies] | ||||||
|  | markdown-it-py = ">=2.2.0" | ||||||
|  | pygments = ">=2.13.0,<3.0.0" | ||||||
|  | typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} | ||||||
|  | 
 | ||||||
|  | [package.extras] | ||||||
|  | jupyter = ["ipywidgets (>=7.5.1,<9)"] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "safetensors" | name = "safetensors" | ||||||
| version = "0.4.5" | version = "0.4.5" | ||||||
|  | @ -888,4 +956,4 @@ zstd = ["zstandard (>=0.18.0)"] | ||||||
| [metadata] | [metadata] | ||||||
| lock-version = "2.0" | lock-version = "2.0" | ||||||
| python-versions = ">=3.10,<3.13" | python-versions = ">=3.10,<3.13" | ||||||
| content-hash = "de2c60fd8f7c54521b2204b3af144b974059361d7892c70eeb325a7fe52bd489" | content-hash = "985797a96ac1b58e4892479d96bbd1c696b452d760fa421a28b91ea8b3dbf977" | ||||||
|  |  | ||||||
|  | @ -62,6 +62,7 @@ python = ">=3.10,<3.13" | ||||||
| transformers = ">=4.0.0" | transformers = ">=4.0.0" | ||||||
| pytest = {version = ">=8.1.0", optional = true} | pytest = {version = ">=8.1.0", optional = true} | ||||||
| requests = "^2.32.3" | requests = "^2.32.3" | ||||||
|  | rich = "^13.9.4" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| [build-system] | [build-system] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue