Support gradio chatbot with continued discussion
This commit is contained in: parent 23ab4a9df3, commit 0ada2ebc27
				|  | @ -144,7 +144,7 @@ class AgentImage(AgentType, ImageType): | |||
|         if self._raw is not None: | ||||
|             directory = tempfile.mkdtemp() | ||||
|             self._path = os.path.join(directory, str(uuid.uuid4()) + ".png") | ||||
|             self._raw.save(self._path) | ||||
|             self._raw.save(self._path, format="png") | ||||
|             return self._path | ||||
| 
 | ||||
|         if self._tensor is not None: | ||||
|  | @ -155,12 +155,11 @@ class AgentImage(AgentType, ImageType): | |||
| 
 | ||||
|             directory = tempfile.mkdtemp() | ||||
|             self._path = os.path.join(directory, str(uuid.uuid4()) + ".png") | ||||
| 
 | ||||
|             img.save(self._path, format="png") | ||||
| 
 | ||||
|             return self._path | ||||
| 
 | ||||
|     def save(self, output_bytes, format = None, **params): | ||||
|     def save(self, output_bytes, format : str = None, **params): | ||||
|         """ | ||||
|         Saves the image to a file. | ||||
|         Args: | ||||
|  |  | |||
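The two `agent_types.py` tweaks above make image serialization explicit: the `save` signature gains a type hint, and `Image.save` is told the output format instead of leaving Pillow to infer it. A minimal sketch of the pinned-format pattern, with a generated test image standing in for the agent's raw PIL image:

```py
import os
import tempfile
import uuid

from PIL import Image

img = Image.new("RGB", (8, 8), "red")  # stand-in for AgentImage._raw
directory = tempfile.mkdtemp()
path = os.path.join(directory, str(uuid.uuid4()) + ".png")

# Passing format="png" keeps the encoding deterministic even if the path
# lacked a recognizable extension; Pillow matches the name case-insensitively.
img.save(path, format="png")
```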
|  | @ -22,7 +22,7 @@ from rich.syntax import Syntax | |||
| from transformers.utils import is_torch_available | ||||
| from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content | ||||
| from .agent_types import AgentAudio, AgentImage | ||||
| from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools | ||||
| from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool | ||||
| from .llm_engine import HfApiEngine, MessageRole | ||||
| from .monitoring import Monitor | ||||
| from .prompts import ( | ||||
|  | @ -42,13 +42,11 @@ from .tools import ( | |||
|     DEFAULT_TOOL_DESCRIPTION_TEMPLATE, | ||||
|     Tool, | ||||
|     get_tool_description_with_args, | ||||
|     load_tool, | ||||
|     Toolbox, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| HUGGINGFACE_DEFAULT_TOOLS = {} | ||||
| _tools_are_initialized = False | ||||
| 
 | ||||
| class AgentError(Exception): | ||||
|     """Base class for other agent-related exceptions""" | ||||
|  | @ -101,9 +99,12 @@ class PlanningStep: | |||
| 
 | ||||
| @dataclass | ||||
| class TaskStep: | ||||
|     system_prompt: str | ||||
|     task: str | ||||
| 
 | ||||
| @dataclass | ||||
| class SystemPromptStep: | ||||
|     system_prompt: str | ||||
| 
 | ||||
| def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: | ||||
|     tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) | ||||
|     prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) | ||||
|  | @ -189,7 +190,7 @@ class BaseAgent: | |||
|             self._toolbox, self.system_prompt_template, self.tool_description_template | ||||
|         ) | ||||
|         self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) | ||||
|         self.prompt = None | ||||
|         self.prompt_messages = None | ||||
|         self.logs = [] | ||||
|         self.task = None | ||||
|         self.verbose = verbose | ||||
|  | @ -208,8 +209,7 @@ class BaseAgent: | |||
|         """Get the toolbox currently available to the agent""" | ||||
|         return self._toolbox | ||||
| 
 | ||||
|     def initialize_for_run(self): | ||||
|         self.token_count = 0 | ||||
|     def initialize_system_prompt(self): | ||||
|         self.system_prompt = format_prompt_with_tools( | ||||
|             self._toolbox, | ||||
|             self.system_prompt_template, | ||||
|  | @ -220,27 +220,25 @@ class BaseAgent: | |||
|             self.system_prompt = format_prompt_with_imports( | ||||
|                 self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) | ||||
|             ) | ||||
|         self.logs = [TaskStep(system_prompt=self.system_prompt, task=self.task)] | ||||
|         console.rule("[bold]New task", characters='=') | ||||
|         console.print(self.task) | ||||
| 
 | ||||
|         return self.system_prompt | ||||
| 
 | ||||
|     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: | ||||
|         """ | ||||
|         Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages | ||||
|         that can be used as input to the LLM. | ||||
|         """ | ||||
|         prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0].system_prompt} | ||||
|         task_message = { | ||||
|             "role": MessageRole.USER, | ||||
|             "content": "Task: " + self.logs[0].task, | ||||
|         } | ||||
|         if summary_mode: | ||||
|             memory = [task_message] | ||||
|         else: | ||||
|             memory = [prompt_message, task_message] | ||||
|         for i, step_log in enumerate(self.logs[1:]): | ||||
|         memory = [] | ||||
|         for i, step_log in enumerate(self.logs): | ||||
|             if isinstance(step_log, SystemPromptStep): | ||||
|                 if not summary_mode: | ||||
|                     thought_message = { | ||||
|                         "role": MessageRole.SYSTEM, | ||||
|                         "content": step_log.system_prompt.strip(), | ||||
|                     } | ||||
|                     memory.append(thought_message) | ||||
| 
 | ||||
|             if isinstance(step_log, PlanningStep): | ||||
|             elif isinstance(step_log, PlanningStep): | ||||
|                 thought_message = { | ||||
|                     "role": MessageRole.ASSISTANT, | ||||
|                     "content": "[FACTS LIST]:\n" + step_log.facts.strip(), | ||||
|  | @ -398,21 +396,21 @@ class ReactAgent(BaseAgent): | |||
|         """ | ||||
|         This method provides a final answer to the task, based on the logs of the agent's interactions. | ||||
|         """ | ||||
|         self.prompt = [ | ||||
|         self.prompt_messages = [ | ||||
|             { | ||||
|                 "role": MessageRole.SYSTEM, | ||||
|                 "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", | ||||
|             } | ||||
|         ] | ||||
|         self.prompt += self.write_inner_memory_from_logs()[1:] | ||||
|         self.prompt += [ | ||||
|         self.prompt_messages += self.write_inner_memory_from_logs()[1:] | ||||
|         self.prompt_messages += [ | ||||
|             { | ||||
|                 "role": MessageRole.USER, | ||||
|                 "content": f"Based on the above, please provide an answer to the following user request:\n{task}", | ||||
|             } | ||||
|         ] | ||||
|         try: | ||||
|             return self.llm_engine(self.prompt) | ||||
|             return self.llm_engine(self.prompt_messages) | ||||
|         except Exception as e: | ||||
|             error_msg = f"Error in generating final LLM output: {e}." | ||||
|             console.print(f"[bold red]{error_msg}[/bold red]") | ||||
|  | @ -423,7 +421,10 @@ class ReactAgent(BaseAgent): | |||
|         Runs the agent for the given task. | ||||
| 
 | ||||
|         Args: | ||||
|             task (`str`): The task to perform | ||||
|             task (`str`): The task to perform. | ||||
|             stream (`bool`): Whether to run in streaming mode. | ||||
|             reset (`bool`): Whether to reset the conversation or continue it from the previous run. | ||||
|             oneshot (`bool`): Whether the agent should run in one shot rather than in a multi-step fashion. | ||||
| 
 | ||||
|         Example: | ||||
|         ```py | ||||
|  | @ -436,10 +437,23 @@ class ReactAgent(BaseAgent): | |||
|         if len(kwargs) > 0: | ||||
|             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." | ||||
|         self.state = kwargs.copy() | ||||
| 
 | ||||
|         self.initialize_system_prompt() | ||||
|         system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) | ||||
| 
 | ||||
|         if reset: | ||||
|             self.initialize_for_run() | ||||
|             self.token_count = 0 | ||||
|             self.logs = [] | ||||
|             self.logs.append(system_prompt_step) | ||||
|         else: | ||||
|             self.logs.append(TaskStep(task=task)) | ||||
|             if len(self.logs) > 0: | ||||
|                 self.logs[0] = system_prompt_step | ||||
|             else: | ||||
|                 self.logs.append(system_prompt_step) | ||||
| 
 | ||||
|         console.rule("[bold]New task", characters='=') | ||||
|         console.print(self.task) | ||||
|         self.logs.append(TaskStep(task=task)) | ||||
| 
 | ||||
|         if oneshot: | ||||
|             step_start_time = time.time() | ||||
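Taken together, the `run()` changes make memory continuation an explicit choice: `reset` clears the logs and re-seeds the system prompt, while `reset=False` only refreshes the stored system prompt step and appends the new task. A hedged usage sketch (the empty tool list and engine checkpoint are illustrative, and this assumes `reset` defaults to `True`):

```py
from agents import CodeAgent, HfApiEngine

llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
agent = CodeAgent(tools=[], llm_engine=llm_engine)

agent.run("Draw me a picture of rivers and lakes.")
# The follow-up keeps the earlier steps in memory, enabling continued discussion:
agent.run("Now make it a watercolor.", reset=False)
```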
|  | @ -676,20 +690,20 @@ class JsonAgent(ReactAgent): | |||
|         """ | ||||
|         agent_memory = self.write_inner_memory_from_logs() | ||||
| 
 | ||||
|         self.prompt = agent_memory | ||||
|         self.prompt_messages = agent_memory | ||||
| 
 | ||||
|         # Add new step in logs | ||||
|         log_entry.agent_memory = agent_memory.copy() | ||||
| 
 | ||||
|         if self.verbose: | ||||
|             console.rule("[italic]Calling LLM engine with this last message:", align="left") | ||||
|             console.print(self.prompt[-1]) | ||||
|             console.print(self.prompt_messages[-1]) | ||||
|             console.rule() | ||||
| 
 | ||||
|         try: | ||||
|             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} | ||||
|             llm_output = self.llm_engine( | ||||
|                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args | ||||
|                 self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args | ||||
|             ) | ||||
|             log_entry.llm_output = llm_output | ||||
|         except Exception as e: | ||||
|  | @ -796,20 +810,20 @@ class CodeAgent(ReactAgent): | |||
|         """ | ||||
|         agent_memory = self.write_inner_memory_from_logs() | ||||
| 
 | ||||
|         self.prompt = agent_memory.copy() | ||||
|         self.prompt_messages = agent_memory.copy() | ||||
| 
 | ||||
|         # Add new step in logs | ||||
|         log_entry.agent_memory = agent_memory.copy() | ||||
| 
 | ||||
|         if self.verbose: | ||||
|             console.rule("[italic]Calling LLM engine with these last messages:", align="left") | ||||
|             console.print(self.prompt[-2:]) | ||||
|             console.print(self.prompt_messages[-2:]) | ||||
|             console.rule() | ||||
| 
 | ||||
|         try: | ||||
|             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} | ||||
|             llm_output = self.llm_engine( | ||||
|                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args | ||||
|                 self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args | ||||
|             ) | ||||
|             log_entry.llm_output = llm_output | ||||
|         except Exception as e: | ||||
|  | @ -893,7 +907,7 @@ You have been submitted this task by your manager. | |||
| Task: | ||||
| {task} | ||||
| --- | ||||
| You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer. | ||||
| You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible to give them a clear understanding of the answer. | ||||
| 
 | ||||
| Your final_answer WILL HAVE to contain these parts: | ||||
| ### 1. Task outcome (short version): | ||||
|  |  | |||
|  | @ -14,7 +14,6 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import importlib.util | ||||
| import json | ||||
| import math | ||||
| from dataclasses import dataclass | ||||
|  | @ -25,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces | |||
| 
 | ||||
| from transformers.utils import is_offline_mode | ||||
| from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code | ||||
| from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool | ||||
| from .tools import TOOL_CONFIG_FILE, Tool | ||||
| 
 | ||||
| 
 | ||||
| def custom_print(*args): | ||||
|  | @ -97,12 +96,6 @@ class PreTool: | |||
|     repo_id: str | ||||
| 
 | ||||
| 
 | ||||
| HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [ | ||||
|     "image-transformation", | ||||
|     "text-to-image", | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
| def get_remote_tools(logger, organization="huggingface-tools"): | ||||
|     if is_offline_mode(): | ||||
|         logger.info("You are in offline mode, so remote tools are not available.") | ||||
|  | @ -128,26 +121,6 @@ def get_remote_tools(logger, organization="huggingface-tools"): | |||
|     return tools | ||||
| 
 | ||||
| 
 | ||||
| def setup_default_tools(): | ||||
|     default_tools = {} | ||||
|     main_module = importlib.import_module("transformers") | ||||
|     tools_module = main_module.agents | ||||
| 
 | ||||
|     for task_name, tool_class_name in TOOL_MAPPING.items(): | ||||
|         tool_class = getattr(tools_module, tool_class_name) | ||||
|         tool_instance = tool_class() | ||||
|         default_tools[tool_class.name] = PreTool( | ||||
|             name=tool_instance.name, | ||||
|             inputs=tool_instance.inputs, | ||||
|             output_type=tool_instance.output_type, | ||||
|             task=task_name, | ||||
|             description=tool_instance.description, | ||||
|             repo_id=None, | ||||
|         ) | ||||
| 
 | ||||
|     return default_tools | ||||
| 
 | ||||
| 
 | ||||
| class PythonInterpreterTool(Tool): | ||||
|     name = "python_interpreter" | ||||
|     description = "This is a tool that evaluates python code. It can be used to perform calculations." | ||||
|  |  | |||
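Note that the `setup_default_tools` helper removed from `default_tools.py` here reappears in `tools.py` further down, now storing live tool instances rather than `PreTool` descriptors, which is also why this module's `importlib.util` import goes away.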
|  | @ -1,88 +0,0 @@ | |||
| #!/usr/bin/env python | ||||
| # coding=utf-8 | ||||
| 
 | ||||
| # Copyright 2023 The HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import re | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| 
 | ||||
| from transformers import AutoProcessor, VisionEncoderDecoderModel | ||||
| from transformers.utils import is_vision_available | ||||
| from .tools import PipelineTool | ||||
| 
 | ||||
| 
 | ||||
| if is_vision_available(): | ||||
|     from PIL import Image | ||||
| 
 | ||||
| 
 | ||||
| class DocumentQuestionAnsweringTool(PipelineTool): | ||||
|     default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa" | ||||
|     description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question." | ||||
|     name = "document_qa" | ||||
|     pre_processor_class = AutoProcessor | ||||
|     model_class = VisionEncoderDecoderModel | ||||
| 
 | ||||
|     inputs = { | ||||
|         "document": { | ||||
|             "type": "image", | ||||
|             "description": "The image containing the information. Can be a PIL Image or a string path to the image.", | ||||
|         }, | ||||
|         "question": {"type": "string", "description": "The question in English"}, | ||||
|     } | ||||
|     output_type = "string" | ||||
| 
 | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         if not is_vision_available(): | ||||
|             raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.") | ||||
| 
 | ||||
|         super().__init__(*args, **kwargs) | ||||
| 
 | ||||
|     def encode(self, document: "Image", question: str): | ||||
|         task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>" | ||||
|         prompt = task_prompt.replace("{user_input}", question) | ||||
|         decoder_input_ids = self.pre_processor.tokenizer( | ||||
|             prompt, add_special_tokens=False, return_tensors="pt" | ||||
|         ).input_ids | ||||
|         if isinstance(document, str): | ||||
|             img = Image.open(document).convert("RGB") | ||||
|             img_array = np.array(img).transpose(2, 0, 1) | ||||
|             document = torch.from_numpy(img_array) | ||||
|         pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values | ||||
| 
 | ||||
|         return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values} | ||||
| 
 | ||||
|     def forward(self, inputs): | ||||
|         return self.model.generate( | ||||
|             inputs["pixel_values"].to(self.device), | ||||
|             decoder_input_ids=inputs["decoder_input_ids"].to(self.device), | ||||
|             max_length=self.model.decoder.config.max_position_embeddings, | ||||
|             early_stopping=True, | ||||
|             pad_token_id=self.pre_processor.tokenizer.pad_token_id, | ||||
|             eos_token_id=self.pre_processor.tokenizer.eos_token_id, | ||||
|             use_cache=True, | ||||
|             num_beams=1, | ||||
|             bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]], | ||||
|             return_dict_in_generate=True, | ||||
|         ).sequences | ||||
| 
 | ||||
|     def decode(self, outputs): | ||||
|         sequence = self.pre_processor.batch_decode(outputs)[0] | ||||
|         sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "") | ||||
|         sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "") | ||||
|         sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token | ||||
|         sequence = self.pre_processor.token2json(sequence) | ||||
| 
 | ||||
|         return sequence["answer"] | ||||
|  | @ -1,58 +0,0 @@ | |||
| #!/usr/bin/env python | ||||
| # coding=utf-8 | ||||
| 
 | ||||
| # Copyright 2023 The HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import torch | ||||
| from PIL import Image | ||||
| 
 | ||||
| from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor | ||||
| from transformers.utils import requires_backends | ||||
| from .tools import PipelineTool | ||||
| 
 | ||||
| 
 | ||||
| class ImageQuestionAnsweringTool(PipelineTool): | ||||
|     default_checkpoint = "dandelin/vilt-b32-finetuned-vqa" | ||||
|     description = ( | ||||
|         "This is a tool that answers a question about an image. It " | ||||
|         "returns a text that is the answer to the question." | ||||
|     ) | ||||
|     name = "image_qa" | ||||
|     pre_processor_class = AutoProcessor | ||||
|     model_class = AutoModelForVisualQuestionAnswering | ||||
| 
 | ||||
|     inputs = { | ||||
|         "image": { | ||||
|             "type": "image", | ||||
|             "description": "The image containing the information. Can be a PIL Image or a string path to the image.", | ||||
|         }, | ||||
|         "question": {"type": "string", "description": "The question in English"}, | ||||
|     } | ||||
|     output_type = "string" | ||||
| 
 | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         requires_backends(self, ["vision"]) | ||||
|         super().__init__(*args, **kwargs) | ||||
| 
 | ||||
|     def encode(self, image: "Image", question: str): | ||||
|         return self.pre_processor(image, question, return_tensors="pt") | ||||
| 
 | ||||
|     def forward(self, inputs): | ||||
|         with torch.no_grad(): | ||||
|             return self.model(**inputs).logits | ||||
| 
 | ||||
|     def decode(self, outputs): | ||||
|         idx = outputs.argmax(-1).item() | ||||
|         return self.model.config.id2label[idx] | ||||
|  | @ -17,20 +17,8 @@ | |||
| from .agent_types import AgentAudio, AgentImage, AgentText | ||||
| from .utils import console | ||||
| 
 | ||||
| 
 | ||||
| def pull_message(step_log: dict, test_mode: bool = True): | ||||
|     try: | ||||
|         from gradio import ChatMessage | ||||
|     except ImportError: | ||||
|         if test_mode: | ||||
| 
 | ||||
|             class ChatMessage: | ||||
|                 def __init__(self, role, content, metadata=None): | ||||
|                     self.role = role | ||||
|                     self.content = content | ||||
|                     self.metadata = metadata | ||||
|         else: | ||||
|             raise ImportError("Gradio should be installed in order to launch a gradio demo.") | ||||
|     from gradio import ChatMessage | ||||
| 
 | ||||
|     if step_log.get("rationale"): | ||||
|         yield ChatMessage(role="assistant", content=step_log["rationale"]) | ||||
|  | @ -54,23 +42,11 @@ def pull_message(step_log: dict, test_mode: bool = True): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs): | ||||
| def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs): | ||||
|     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" | ||||
|     from gradio import ChatMessage | ||||
| 
 | ||||
|     try: | ||||
|         from gradio import ChatMessage | ||||
|     except ImportError: | ||||
|         if test_mode: | ||||
| 
 | ||||
|             class ChatMessage: | ||||
|                 def __init__(self, role, content, metadata=None): | ||||
|                     self.role = role | ||||
|                     self.content = content | ||||
|                     self.metadata = metadata | ||||
|         else: | ||||
|             raise ImportError("Gradio should be installed in order to launch a gradio demo.") | ||||
| 
 | ||||
|     for step_log in agent.run(task, stream=True, **kwargs): | ||||
|     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs): | ||||
|         if isinstance(step_log, dict): | ||||
|             for message in pull_message(step_log, test_mode=test_mode): | ||||
|                 yield message | ||||
|  |  | |||
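With the fallback `ChatMessage` stub removed, `pull_message` and `stream_to_gradio` now import Gradio unconditionally; this matches the `gradio = "^5.8.0"` dependency added to `pyproject.toml` below, so the import is expected to succeed in any correctly installed environment.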
|  | @ -395,7 +395,7 @@ Do not add anything else.""" | |||
| SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools. | ||||
| 
 | ||||
| Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts. | ||||
| This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer. | ||||
| This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer. | ||||
| Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS. | ||||
| After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.""" | ||||
| 
 | ||||
|  | @ -466,7 +466,7 @@ Here is the up to date list of facts that you know: | |||
| ``` | ||||
| 
 | ||||
| Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts. | ||||
| This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer. | ||||
| This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer. | ||||
| Beware that you have {remaining_steps} steps remaining. | ||||
| Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS. | ||||
| After writing the final step of the plan, write the '\n<end_plan>' tag and stop there. | ||||
|  |  | |||
|  | @ -1,39 +0,0 @@ | |||
| #!/usr/bin/env python | ||||
| # coding=utf-8 | ||||
| 
 | ||||
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from transformers import WhisperForConditionalGeneration, WhisperProcessor | ||||
| from .tools import PipelineTool | ||||
| 
 | ||||
| 
 | ||||
| class SpeechToTextTool(PipelineTool): | ||||
|     default_checkpoint = "distil-whisper/distil-large-v3" | ||||
|     description = "This is a tool that transcribes an audio into text. It returns the transcribed text." | ||||
|     name = "transcriber" | ||||
|     pre_processor_class = WhisperProcessor | ||||
|     model_class = WhisperForConditionalGeneration | ||||
| 
 | ||||
|     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} | ||||
|     output_type = "string" | ||||
| 
 | ||||
|     def encode(self, audio): | ||||
|         return self.pre_processor(audio, return_tensors="pt") | ||||
| 
 | ||||
|     def forward(self, inputs): | ||||
|         return self.model.generate(inputs["input_features"]) | ||||
| 
 | ||||
|     def decode(self, outputs): | ||||
|         return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0] | ||||
|  | @ -1,67 +0,0 @@ | |||
| #!/usr/bin/env python | ||||
| # coding=utf-8 | ||||
| 
 | ||||
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import torch | ||||
| 
 | ||||
| from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor | ||||
| from transformers.utils import is_datasets_available | ||||
| from .tools import PipelineTool | ||||
| 
 | ||||
| 
 | ||||
| if is_datasets_available(): | ||||
|     from datasets import load_dataset | ||||
| 
 | ||||
| 
 | ||||
| class TextToSpeechTool(PipelineTool): | ||||
|     default_checkpoint = "microsoft/speecht5_tts" | ||||
|     description = ( | ||||
|         "This is a tool that reads an English text out loud. It returns a waveform object containing the sound." | ||||
|     ) | ||||
|     name = "text_to_speech" | ||||
|     pre_processor_class = SpeechT5Processor | ||||
|     model_class = SpeechT5ForTextToSpeech | ||||
|     post_processor_class = SpeechT5HifiGan | ||||
| 
 | ||||
|     inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}} | ||||
|     output_type = "audio" | ||||
| 
 | ||||
|     def setup(self): | ||||
|         if self.post_processor is None: | ||||
|             self.post_processor = "microsoft/speecht5_hifigan" | ||||
|         super().setup() | ||||
| 
 | ||||
|     def encode(self, text, speaker_embeddings=None): | ||||
|         inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True) | ||||
| 
 | ||||
|         if speaker_embeddings is None: | ||||
|             if not is_datasets_available(): | ||||
|                 raise ImportError("Datasets needs to be installed if not passing speaker embeddings.") | ||||
| 
 | ||||
|             embeddings_dataset = load_dataset( | ||||
|                 "Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True | ||||
|             ) | ||||
|             speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0) | ||||
| 
 | ||||
|         return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings} | ||||
| 
 | ||||
|     def forward(self, inputs): | ||||
|         with torch.no_grad(): | ||||
|             return self.model.generate_speech(**inputs) | ||||
| 
 | ||||
|     def decode(self, outputs): | ||||
|         with torch.no_grad(): | ||||
|             return self.post_processor(outputs).cpu().detach() | ||||
|  | @ -80,6 +80,19 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs): | |||
|         return "space" | ||||
| 
 | ||||
| 
 | ||||
| def setup_default_tools(): | ||||
|     default_tools = {} | ||||
|     main_module = importlib.import_module("transformers") | ||||
|     tools_module = main_module.agents | ||||
| 
 | ||||
|     for task_name, tool_class_name in TOOL_MAPPING.items(): | ||||
|         tool_class = getattr(tools_module, tool_class_name) | ||||
|         tool_instance = tool_class() | ||||
|         default_tools[tool_class.name] = tool_instance | ||||
| 
 | ||||
|     return default_tools | ||||
| 
 | ||||
| 
 | ||||
| # docstyle-ignore | ||||
| APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo | ||||
| from {module_name} import {class_name} | ||||
|  | @ -811,11 +824,6 @@ def launch_gradio_demo(tool_class: Tool): | |||
| 
 | ||||
| 
 | ||||
| TOOL_MAPPING = { | ||||
|     "document_question_answering": "DocumentQuestionAnsweringTool", | ||||
|     "image_question_answering": "ImageQuestionAnsweringTool", | ||||
|     "speech_to_text": "SpeechToTextTool", | ||||
|     "text_to_speech": "TextToSpeechTool", | ||||
|     "translation": "TranslationTool", | ||||
|     "python_interpreter": "PythonInterpreterTool", | ||||
|     "web_search": "DuckDuckGoSearchTool", | ||||
| } | ||||
|  | @ -1018,18 +1026,14 @@ class Toolbox: | |||
|         self._tools = {tool.name: tool for tool in tools} | ||||
|         if add_base_tools: | ||||
|             self.add_base_tools() | ||||
|         # self._load_tools_if_needed() | ||||
| 
 | ||||
|     def add_base_tools(self, add_python_interpreter: bool = False): | ||||
|         global _tools_are_initialized | ||||
|         global HUGGINGFACE_DEFAULT_TOOLS | ||||
|         if not _tools_are_initialized: | ||||
|         if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0: | ||||
|             HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools() | ||||
|             _tools_are_initialized = True | ||||
|         for tool in HUGGINGFACE_DEFAULT_TOOLS.values(): | ||||
|             if tool.name != "python_interpreter" or add_python_interpreter: | ||||
|                 self.add_tool(tool) | ||||
|         # self._load_tools_if_needed() | ||||
| 
 | ||||
|     @property | ||||
|     def tools(self) -> Dict[str, Tool]: | ||||
|  |  | |||
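The `add_base_tools` change drops the module-level `_tools_are_initialized` flag and lets the cache speak for itself: an empty `HUGGINGFACE_DEFAULT_TOOLS` dict means "not populated yet". A minimal sketch of the pattern, with a stub standing in for the real `setup_default_tools` and a hypothetical accessor name:

```py
HUGGINGFACE_DEFAULT_TOOLS = {}


def setup_default_tools():
    # Stub for illustration; the real factory instantiates every
    # tool class listed in TOOL_MAPPING.
    return {"final_answer": object()}


def get_default_tools():  # hypothetical name, for illustration only
    global HUGGINGFACE_DEFAULT_TOOLS
    if len(HUGGINGFACE_DEFAULT_TOOLS) == 0:
        HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()  # populate once
    return HUGGINGFACE_DEFAULT_TOOLS
```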
|  | @ -1,4 +1,4 @@ | |||
| from agents import load_tool, ReactCodeAgent, HfApiEngine | ||||
| from agents import load_tool, CodeAgent, HfApiEngine | ||||
| 
 | ||||
| # Import tool from Hub | ||||
| image_generation_tool = load_tool("m-ric/text-to-image", cache=False) | ||||
|  | @ -10,7 +10,7 @@ search_tool = DuckDuckGoSearchTool() | |||
| 
 | ||||
| llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct") | ||||
| # Initialize the agent with both tools | ||||
| agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine) | ||||
| agent = CodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine) | ||||
| 
 | ||||
| # Run it! | ||||
| result = agent.run( | ||||
|  |  | |||
|  | @ -0,0 +1,29 @@ | |||
| from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent | ||||
| 
 | ||||
| image_generation_tool = load_tool("m-ric/text-to-image") | ||||
| 
 | ||||
| llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct") | ||||
| 
 | ||||
| agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine) | ||||
| 
 | ||||
| import gradio as gr | ||||
| 
 | ||||
| 
 | ||||
| def interact_with_agent(prompt, messages): | ||||
|     messages.append(gr.ChatMessage(role="user", content=prompt)) | ||||
|     yield messages | ||||
|     for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False): | ||||
|         messages.append(msg) | ||||
|         yield messages | ||||
|     yield messages | ||||
| 
 | ||||
| 
 | ||||
| with gr.Blocks() as demo: | ||||
|     stored_message = gr.State([]) | ||||
|     chatbot = gr.Chatbot(label="Agent", | ||||
|                          type="messages", | ||||
|                          avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png")) | ||||
|     text_input = gr.Textbox(lines=1, label="Chat Message") | ||||
|     text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot]) | ||||
| 
 | ||||
| demo.launch() | ||||
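Because this demo calls `stream_to_gradio` with `reset_agent_memory=False`, every submitted chat message extends the same agent memory, so follow-up prompts can reference earlier turns; this is the continued-discussion behavior named in the commit title.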
										
											
[File diff suppressed because it is too large]
							|  | @ -68,8 +68,12 @@ jinja2 = "^3.1.4" | |||
| pillow = "^11.0.0" | ||||
| llama-cpp-python = "^0.3.4" | ||||
| markdownify = "^0.14.1" | ||||
| gradio = "^5.8.0" | ||||
| 
 | ||||
| 
 | ||||
| [tool.poetry.group.dev.dependencies] | ||||
| ipykernel = "^6.29.5" | ||||
| 
 | ||||
| [build-system] | ||||
| requires = ["poetry-core"] | ||||
| build-backend = "poetry.core.masonry.api" | ||||