Support gradio chatbot with continued discussion
parent 23ab4a9df3
commit 0ada2ebc27

@@ -144,7 +144,7 @@ class AgentImage(AgentType, ImageType):
         if self._raw is not None:
             directory = tempfile.mkdtemp()
             self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
-            self._raw.save(self._path)
+            self._raw.save(self._path, format="png")
             return self._path
 
         if self._tensor is not None:
@@ -155,12 +155,11 @@ class AgentImage(AgentType, ImageType):
 
             directory = tempfile.mkdtemp()
             self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
-
             img.save(self._path, format="png")
 
             return self._path
 
-    def save(self, output_bytes, format = None, **params):
+    def save(self, output_bytes, format : str = None, **params):
         """
         Saves the image to a file.
         Args:

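The first hunk above pins down the output format when `AgentImage` serializes a raw `PIL.Image`. A minimal sketch of the same pattern, assuming only Pillow is installed (the tiny image is a stand-in for `self._raw`):

```py
import os
import tempfile
import uuid

from PIL import Image

img = Image.new("RGB", (8, 8))  # stand-in for self._raw
directory = tempfile.mkdtemp()
path = os.path.join(directory, str(uuid.uuid4()) + ".png")
# Pillow would infer PNG from the ".png" suffix; passing format="png"
# makes the output format explicit and independent of the path.
img.save(path, format="png")
```
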
@@ -22,7 +22,7 @@ from rich.syntax import Syntax
 from transformers.utils import is_torch_available
 from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
 from .agent_types import AgentAudio, AgentImage
-from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
+from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
 from .llm_engine import HfApiEngine, MessageRole
 from .monitoring import Monitor
 from .prompts import (
@@ -42,13 +42,11 @@ from .tools import (
     DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
     Tool,
     get_tool_description_with_args,
-    load_tool,
     Toolbox,
 )
 
 
 HUGGINGFACE_DEFAULT_TOOLS = {}
-_tools_are_initialized = False
 
 class AgentError(Exception):
     """Base class for other agent-related exceptions"""
@@ -101,9 +99,12 @@ class PlanningStep:
 
 @dataclass
 class TaskStep:
-    system_prompt: str
     task: str
 
+@dataclass
+class SystemPromptStep:
+    system_prompt: str
+
 def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
     tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
     prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
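
Splitting the system prompt out of `TaskStep` into its own `SystemPromptStep` is what allows a later run to refresh the system prompt in place while task steps accumulate. A hypothetical sketch of the resulting log layout (the two dataclasses come from the hunk above; the example tasks are invented):

```py
from dataclasses import dataclass

@dataclass
class SystemPromptStep:
    system_prompt: str

@dataclass
class TaskStep:
    task: str

# One system prompt entry, then one TaskStep per conversational turn.
logs = [SystemPromptStep(system_prompt="You are an expert assistant...")]
logs.append(TaskStep(task="Generate an image of a cat"))
logs.append(TaskStep(task="Now add a hat to it"))  # continued discussion
print([type(step).__name__ for step in logs])
# ['SystemPromptStep', 'TaskStep', 'TaskStep']
```
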
@@ -189,7 +190,7 @@ class BaseAgent:
             self._toolbox, self.system_prompt_template, self.tool_description_template
         )
         self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
-        self.prompt = None
+        self.prompt_messages = None
         self.logs = []
         self.task = None
         self.verbose = verbose
@@ -208,8 +209,7 @@ class BaseAgent:
         """Get the toolbox currently available to the agent"""
         return self._toolbox
 
-    def initialize_for_run(self):
-        self.token_count = 0
+    def initialize_system_prompt(self):
         self.system_prompt = format_prompt_with_tools(
             self._toolbox,
             self.system_prompt_template,
@@ -220,27 +220,25 @@ class BaseAgent:
             self.system_prompt = format_prompt_with_imports(
                 self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
             )
-        self.logs = [TaskStep(system_prompt=self.system_prompt, task=self.task)]
-        console.rule("[bold]New task", characters='=')
-        console.print(self.task)
+
+        return self.system_prompt
 
     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
         """
         Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
         that can be used as input to the LLM.
         """
-        prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0].system_prompt}
-        task_message = {
-            "role": MessageRole.USER,
-            "content": "Task: " + self.logs[0].task,
-        }
-        if summary_mode:
-            memory = [task_message]
-        else:
-            memory = [prompt_message, task_message]
-        for i, step_log in enumerate(self.logs[1:]):
+        memory = []
+        for i, step_log in enumerate(self.logs):
+            if isinstance(step_log, SystemPromptStep):
+                if not summary_mode:
+                    thought_message = {
+                        "role": MessageRole.SYSTEM,
+                        "content": step_log.system_prompt.strip(),
+                    }
+                    memory.append(thought_message)
 
-            if isinstance(step_log, PlanningStep):
+            elif isinstance(step_log, PlanningStep):
                 thought_message = {
                     "role": MessageRole.ASSISTANT,
                     "content": "[FACTS LIST]:\n" + step_log.facts.strip(),
@@ -398,21 +396,21 @@ class ReactAgent(BaseAgent):
         """
         This method provides a final answer to the task, based on the logs of the agent's interactions.
         """
-        self.prompt = [
+        self.prompt_messages = [
             {
                 "role": MessageRole.SYSTEM,
                 "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
             }
         ]
-        self.prompt += self.write_inner_memory_from_logs()[1:]
-        self.prompt += [
+        self.prompt_messages += self.write_inner_memory_from_logs()[1:]
+        self.prompt_messages += [
             {
                 "role": MessageRole.USER,
                 "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
             }
         ]
         try:
-            return self.llm_engine(self.prompt)
+            return self.llm_engine(self.prompt_messages)
         except Exception as e:
             error_msg = f"Error in generating final LLM output: {e}."
             console.print(f"[bold red]{error_msg}[/bold red]")
@@ -423,7 +421,10 @@ class ReactAgent(BaseAgent):
         Runs the agent for the given task.
 
         Args:
-            task (`str`): The task to perform
+            task (`str`): The task to perform.
+            stream (`bool`): Whether to run in a streaming way.
+            reset (`bool`): Whether to reset the conversation or keep it going from a previous run.
+            oneshot (`bool`): Whether the agent should run in one shot or in a multi-step fashion.
 
         Example:
         ```py
@@ -436,10 +437,23 @@ class ReactAgent(BaseAgent):
         if len(kwargs) > 0:
             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
         self.state = kwargs.copy()
+
+        self.initialize_system_prompt()
+        system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
+
         if reset:
-            self.initialize_for_run()
+            self.token_count = 0
+            self.logs = []
+            self.logs.append(system_prompt_step)
         else:
-            self.logs.append(TaskStep(task=task))
+            if len(self.logs) > 0:
+                self.logs[0] = system_prompt_step
+            else:
+                self.logs.append(system_prompt_step)
+
+        console.rule("[bold]New task", characters='=')
+        console.print(self.task)
+        self.logs.append(TaskStep(task=task))
 
         if oneshot:
             step_start_time = time.time()
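
This hunk is the heart of the continued-discussion feature: the system prompt step is rebuilt on every call and written into `logs[0]`, so toolbox changes between turns are picked up, while `reset=False` keeps earlier task steps in memory. A hedged usage sketch (the engine and model name are borrowed from the examples later in this commit, and `reset` is assumed to default to `True`):

```py
from agents import CodeAgent, HfApiEngine

agent = CodeAgent(tools=[], llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"))

agent.run("What is 2 + 2?")                           # fresh logs: [SystemPromptStep, TaskStep, ...]
agent.run("Multiply that result by 10", reset=False)  # appends another TaskStep, keeps prior memory
```
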
@@ -676,20 +690,20 @@ class JsonAgent(ReactAgent):
         """
         agent_memory = self.write_inner_memory_from_logs()
 
-        self.prompt = agent_memory
+        self.prompt_messages = agent_memory
 
         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()
 
         if self.verbose:
             console.rule("[italic]Calling LLM engine with this last message:", align="left")
-            console.print(self.prompt[-1])
+            console.print(self.prompt_messages[-1])
             console.rule()
 
         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
             llm_output = self.llm_engine(
-                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
             )
             log_entry.llm_output = llm_output
         except Exception as e:
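
As the call sites above suggest, `llm_engine` is just a callable that takes the message list plus optional stop sequences and returns a string; this is inferred from the calls shown here, not from an interface definition. A minimal stand-in:

```py
def dummy_llm_engine(messages, stop_sequences=None, **kwargs):
    # messages is the list of {"role": ..., "content": ...} dicts built by
    # write_inner_memory_from_logs; a real engine would query an LLM here.
    last = messages[-1]["content"]
    return f"I received: {last[:40]}"
```
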
@@ -796,20 +810,20 @@ class CodeAgent(ReactAgent):
         """
         agent_memory = self.write_inner_memory_from_logs()
 
-        self.prompt = agent_memory.copy()
+        self.prompt_messages = agent_memory.copy()
 
         # Add new step in logs
         log_entry.agent_memory = agent_memory.copy()
 
         if self.verbose:
             console.rule("[italic]Calling LLM engine with these last messages:", align="left")
-            console.print(self.prompt[-2:])
+            console.print(self.prompt_messages[-2:])
             console.rule()
 
         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
             llm_output = self.llm_engine(
-                self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+                self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
             )
             log_entry.llm_output = llm_output
         except Exception as e:
@@ -893,7 +907,7 @@ You have been submitted this task by your manager.
 Task:
 {task}
 ---
-You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer.
+You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible to give them a clear understanding of the answer.
 
 Your final_answer WILL HAVE to contain these parts:
 ### 1. Task outcome (short version):

@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
 import json
 import math
 from dataclasses import dataclass
@@ -25,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces
 
 from transformers.utils import is_offline_mode
 from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
-from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
+from .tools import TOOL_CONFIG_FILE, Tool
 
 
 def custom_print(*args):
@@ -97,12 +96,6 @@ class PreTool:
     repo_id: str
 
 
-HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
-    "image-transformation",
-    "text-to-image",
-]
-
-
 def get_remote_tools(logger, organization="huggingface-tools"):
     if is_offline_mode():
         logger.info("You are in offline mode, so remote tools are not available.")
@@ -128,26 +121,6 @@ def get_remote_tools(logger, organization="huggingface-tools"):
     return tools
 
 
-def setup_default_tools():
-    default_tools = {}
-    main_module = importlib.import_module("transformers")
-    tools_module = main_module.agents
-
-    for task_name, tool_class_name in TOOL_MAPPING.items():
-        tool_class = getattr(tools_module, tool_class_name)
-        tool_instance = tool_class()
-        default_tools[tool_class.name] = PreTool(
-            name=tool_instance.name,
-            inputs=tool_instance.inputs,
-            output_type=tool_instance.output_type,
-            task=task_name,
-            description=tool_instance.description,
-            repo_id=None,
-        )
-
-    return default_tools
-
-
 class PythonInterpreterTool(Tool):
     name = "python_interpreter"
     description = "This is a tool that evaluates python code. It can be used to perform calculations."

@@ -1,88 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import re
-
-import numpy as np
-import torch
-
-from transformers import AutoProcessor, VisionEncoderDecoderModel
-from transformers.utils import is_vision_available
-from .tools import PipelineTool
-
-
-if is_vision_available():
-    from PIL import Image
-
-
-class DocumentQuestionAnsweringTool(PipelineTool):
-    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
-    description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
-    name = "document_qa"
-    pre_processor_class = AutoProcessor
-    model_class = VisionEncoderDecoderModel
-
-    inputs = {
-        "document": {
-            "type": "image",
-            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
-        },
-        "question": {"type": "string", "description": "The question in English"},
-    }
-    output_type = "string"
-
-    def __init__(self, *args, **kwargs):
-        if not is_vision_available():
-            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
-
-        super().__init__(*args, **kwargs)
-
-    def encode(self, document: "Image", question: str):
-        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
-        prompt = task_prompt.replace("{user_input}", question)
-        decoder_input_ids = self.pre_processor.tokenizer(
-            prompt, add_special_tokens=False, return_tensors="pt"
-        ).input_ids
-        if isinstance(document, str):
-            img = Image.open(document).convert("RGB")
-            img_array = np.array(img).transpose(2, 0, 1)
-            document = torch.from_numpy(img_array)
-        pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
-
-        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
-
-    def forward(self, inputs):
-        return self.model.generate(
-            inputs["pixel_values"].to(self.device),
-            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
-            max_length=self.model.decoder.config.max_position_embeddings,
-            early_stopping=True,
-            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
-            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
-            use_cache=True,
-            num_beams=1,
-            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
-            return_dict_in_generate=True,
-        ).sequences
-
-    def decode(self, outputs):
-        sequence = self.pre_processor.batch_decode(outputs)[0]
-        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
-        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
-        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
-        sequence = self.pre_processor.token2json(sequence)
-
-        return sequence["answer"]
@@ -1,58 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from PIL import Image
-
-from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
-from transformers.utils import requires_backends
-from .tools import PipelineTool
-
-
-class ImageQuestionAnsweringTool(PipelineTool):
-    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
-    description = (
-        "This is a tool that answers a question about an image. It "
-        "returns a text that is the answer to the question."
-    )
-    name = "image_qa"
-    pre_processor_class = AutoProcessor
-    model_class = AutoModelForVisualQuestionAnswering
-
-    inputs = {
-        "image": {
-            "type": "image",
-            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
-        },
-        "question": {"type": "string", "description": "The question in English"},
-    }
-    output_type = "string"
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["vision"])
-        super().__init__(*args, **kwargs)
-
-    def encode(self, image: "Image", question: str):
-        return self.pre_processor(image, question, return_tensors="pt")
-
-    def forward(self, inputs):
-        with torch.no_grad():
-            return self.model(**inputs).logits
-
-    def decode(self, outputs):
-        idx = outputs.argmax(-1).item()
-        return self.model.config.id2label[idx]
@@ -17,20 +17,8 @@
 from .agent_types import AgentAudio, AgentImage, AgentText
 from .utils import console
 
-
 def pull_message(step_log: dict, test_mode: bool = True):
-    try:
-        from gradio import ChatMessage
-    except ImportError:
-        if test_mode:
-
-            class ChatMessage:
-                def __init__(self, role, content, metadata=None):
-                    self.role = role
-                    self.content = content
-                    self.metadata = metadata
-        else:
-            raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+    from gradio import ChatMessage
 
     if step_log.get("rationale"):
         yield ChatMessage(role="assistant", content=step_log["rationale"])
@@ -54,23 +42,11 @@ def pull_message(step_log: dict, test_mode: bool = True):
         )
 
 
-def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
+def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
+    from gradio import ChatMessage
 
-    try:
-        from gradio import ChatMessage
-    except ImportError:
-        if test_mode:
-
-            class ChatMessage:
-                def __init__(self, role, content, metadata=None):
-                    self.role = role
-                    self.content = content
-                    self.metadata = metadata
-        else:
-            raise ImportError("Gradio should be installed in order to launch a gradio demo.")
-
-    for step_log in agent.run(task, stream=True, **kwargs):
+    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
         if isinstance(step_log, dict):
             for message in pull_message(step_log, test_mode=test_mode):
                 yield message

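`stream_to_gradio` now simply forwards `reset_agent_memory` to `agent.run(..., reset=...)`, so a chat UI can keep one agent conversation alive across turns. A hedged sketch of driving it outside Gradio (agent setup borrowed from the demo script added later in this commit):

```py
from agents import CodeAgent, HfApiEngine, load_tool, stream_to_gradio

agent = CodeAgent(tools=[load_tool("m-ric/text-to-image")], llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"))

# Each yielded item is a gradio ChatMessage; reset_agent_memory=False keeps
# the conversation going from one call to the next.
for message in stream_to_gradio(agent, task="Draw me a cat", reset_agent_memory=False):
    print(message.role, message.content)
```
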
@@ -395,7 +395,7 @@ Do not add anything else."""
 SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
 
 Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
-This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
 Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
 After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
 
@@ -466,7 +466,7 @@ Here is the up to date list of facts that you know:
 ```
 
 Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
-This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
 Beware that you have {remaining_steps} steps remaining.
 Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
 After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.

@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import WhisperForConditionalGeneration, WhisperProcessor
-from .tools import PipelineTool
-
-
-class SpeechToTextTool(PipelineTool):
-    default_checkpoint = "distil-whisper/distil-large-v3"
-    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
-    name = "transcriber"
-    pre_processor_class = WhisperProcessor
-    model_class = WhisperForConditionalGeneration
-
-    inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
-    output_type = "string"
-
-    def encode(self, audio):
-        return self.pre_processor(audio, return_tensors="pt")
-
-    def forward(self, inputs):
-        return self.model.generate(inputs["input_features"])
-
-    def decode(self, outputs):
-        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
@@ -1,67 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
-from transformers.utils import is_datasets_available
-from .tools import PipelineTool
-
-
-if is_datasets_available():
-    from datasets import load_dataset
-
-
-class TextToSpeechTool(PipelineTool):
-    default_checkpoint = "microsoft/speecht5_tts"
-    description = (
-        "This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
-    )
-    name = "text_to_speech"
-    pre_processor_class = SpeechT5Processor
-    model_class = SpeechT5ForTextToSpeech
-    post_processor_class = SpeechT5HifiGan
-
-    inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
-    output_type = "audio"
-
-    def setup(self):
-        if self.post_processor is None:
-            self.post_processor = "microsoft/speecht5_hifigan"
-        super().setup()
-
-    def encode(self, text, speaker_embeddings=None):
-        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
-
-        if speaker_embeddings is None:
-            if not is_datasets_available():
-                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
-
-            embeddings_dataset = load_dataset(
-                "Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True
-            )
-            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
-
-        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
-
-    def forward(self, inputs):
-        with torch.no_grad():
-            return self.model.generate_speech(**inputs)
-
-    def decode(self, outputs):
-        with torch.no_grad():
-            return self.post_processor(outputs).cpu().detach()
@@ -80,6 +80,19 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
         return "space"
 
 
+def setup_default_tools():
+    default_tools = {}
+    main_module = importlib.import_module("transformers")
+    tools_module = main_module.agents
+
+    for task_name, tool_class_name in TOOL_MAPPING.items():
+        tool_class = getattr(tools_module, tool_class_name)
+        tool_instance = tool_class()
+        default_tools[tool_class.name] = tool_instance
+
+    return default_tools
+
+
 # docstyle-ignore
 APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
 from {module_name} import {class_name}
@@ -811,11 +824,6 @@ def launch_gradio_demo(tool_class: Tool):
 
 
 TOOL_MAPPING = {
-    "document_question_answering": "DocumentQuestionAnsweringTool",
-    "image_question_answering": "ImageQuestionAnsweringTool",
-    "speech_to_text": "SpeechToTextTool",
-    "text_to_speech": "TextToSpeechTool",
-    "translation": "TranslationTool",
     "python_interpreter": "PythonInterpreterTool",
     "web_search": "DuckDuckGoSearchTool",
 }
@@ -1018,18 +1026,14 @@ class Toolbox:
         self._tools = {tool.name: tool for tool in tools}
         if add_base_tools:
             self.add_base_tools()
-        # self._load_tools_if_needed()
 
     def add_base_tools(self, add_python_interpreter: bool = False):
-        global _tools_are_initialized
         global HUGGINGFACE_DEFAULT_TOOLS
-        if not _tools_are_initialized:
+        if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
            HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
-            _tools_are_initialized = True
         for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
             if tool.name != "python_interpreter" or add_python_interpreter:
                 self.add_tool(tool)
-        # self._load_tools_if_needed()
 
     @property
     def tools(self) -> Dict[str, Tool]:

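The `_tools_are_initialized` flag could be dropped because the empty `HUGGINGFACE_DEFAULT_TOOLS` dict already encodes the same state: the first `Toolbox` that requests base tools pays the `setup_default_tools()` cost, and later ones reuse the populated cache. A minimal sketch of the same lazy-initialization idiom, under that reading:

```py
_CACHE = {}

def get_defaults(setup):
    # setup() runs once; afterwards the non-empty dict short-circuits it,
    # mirroring the len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0 check above.
    global _CACHE
    if len(_CACHE) == 0:
        _CACHE = setup()
    return _CACHE
```
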
@@ -1,4 +1,4 @@
-from agents import load_tool, ReactCodeAgent, HfApiEngine
+from agents import load_tool, CodeAgent, HfApiEngine
 
 # Import tool from Hub
 image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
@@ -10,7 +10,7 @@ search_tool = DuckDuckGoSearchTool()
 
 llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
 # Initialize the agent with both tools
-agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)
+agent = CodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)
 
 # Run it!
 result = agent.run(

@@ -0,0 +1,29 @@
+from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent
+
+image_generation_tool = load_tool("m-ric/text-to-image")
+
+llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
+
+agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
+
+import gradio as gr
+
+
+def interact_with_agent(prompt, messages):
+    messages.append(gr.ChatMessage(role="user", content=prompt))
+    yield messages
+    for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False):
+        messages.append(msg)
+        yield messages
+    yield messages
+
+
+with gr.Blocks() as demo:
+    stored_message = gr.State([])
+    chatbot = gr.Chatbot(label="Agent",
+                         type="messages",
+                         avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
+    text_input = gr.Textbox(lines=1, label="Chat Message")
+    text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])
+
+demo.launch()
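
In this demo, `text_input.submit(...)` first stashes the message and clears the textbox, and the chained `.then(...)` streams the agent's answers into the chatbot; `reset_agent_memory=False` is what makes the chatbot continue the discussion across turns. To start every turn from a clean slate instead, the flag would presumably just flip (a variant of the function above, reusing the same `agent` and imports):

```py
def interact_with_fresh_agent(prompt, messages):
    messages.append(gr.ChatMessage(role="user", content=prompt))
    yield messages
    for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=True):
        messages.append(msg)
        yield messages
```
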
										
											
(File diff suppressed because it is too large)

@@ -68,8 +68,12 @@ jinja2 = "^3.1.4"
 pillow = "^11.0.0"
 llama-cpp-python = "^0.3.4"
 markdownify = "^0.14.1"
+gradio = "^5.8.0"
 
 
+[tool.poetry.group.dev.dependencies]
+ipykernel = "^6.29.5"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"