diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..897944e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,25 @@
+# Logging
+logs
+tmp
+wandb
+
+# Data
+data
+outputs
+
+# Apple
+.DS_Store
+
+# VS Code
+.vscode
+
+# Environments
+.env
+.venv
+env/
+venv/
+env.bak/
+venv.bak/
+
+# Jupyter Notebook
+.ipynb_checkpoints
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..fc64f02
--- /dev/null
+++ b/README.md
@@ -0,0 +1,277 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Run your *raw* PyTorch training script on any kind of device
+
+
+
+
+
+
+## Easy to integrate
+
+๐ค Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16.
+
+๐ค Accelerate abstracts exactly and only the boilerplate code related to multi-GPUs/TPU/fp16 and leaves the rest of your code unchanged.
+
+Here is an example:
+
+```diff
+ import torch
+ import torch.nn.functional as F
+ from datasets import load_dataset
++ from accelerate import Accelerator
+
++ accelerator = Accelerator()
+- device = 'cpu'
++ device = accelerator.device
+
+ model = torch.nn.Transformer().to(device)
+ optimizer = torch.optim.Adam(model.parameters())
+
+ dataset = load_dataset('my_dataset')
+ data = torch.utils.data.DataLoader(dataset, shuffle=True)
+
++ model, optimizer, data = accelerator.prepare(model, optimizer, data)
+
+ model.train()
+ for epoch in range(10):
+ for source, targets in data:
+ source = source.to(device)
+ targets = targets.to(device)
+
+ optimizer.zero_grad()
+
+ output = model(source)
+ loss = F.cross_entropy(output, targets)
+
+- loss.backward()
++ accelerator.backward(loss)
+
+ optimizer.step()
+```
+
+As you can see in this example, by adding 5-lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp8, fp16, bf16).
+
+In particular, the same code can then be run without modification on your local machine for debugging or your training environment.
+
+๐ค Accelerate even handles the device placement for you (which requires a few more changes to your code, but is safer in general), so you can even simplify your training loop further:
+
+```diff
+ import torch
+ import torch.nn.functional as F
+ from datasets import load_dataset
++ from accelerate import Accelerator
+
+- device = 'cpu'
++ accelerator = Accelerator()
+
+- model = torch.nn.Transformer().to(device)
++ model = torch.nn.Transformer()
+ optimizer = torch.optim.Adam(model.parameters())
+
+ dataset = load_dataset('my_dataset')
+ data = torch.utils.data.DataLoader(dataset, shuffle=True)
+
++ model, optimizer, data = accelerator.prepare(model, optimizer, data)
+
+ model.train()
+ for epoch in range(10):
+ for source, targets in data:
+- source = source.to(device)
+- targets = targets.to(device)
+
+ optimizer.zero_grad()
+
+ output = model(source)
+ loss = F.cross_entropy(output, targets)
+
+- loss.backward()
++ accelerator.backward(loss)
+
+ optimizer.step()
+```
+
+Want to learn more? Check out the [documentation](https://huggingface.co/docs/accelerate) or have a look at our [examples](https://github.com/huggingface/accelerate/tree/main/examples).
+
+## Launching script
+
+๐ค Accelerate also provides an optional CLI tool that allows you to quickly configure and test your training environment before launching the scripts. No need to remember how to use `torch.distributed.run` or to write a specific launcher for TPU training!
+On your machine(s) just run:
+
+```bash
+accelerate config
+```
+
+and answer the questions asked. This will generate a config file that will be used automatically to properly set the default options when doing
+
+```bash
+accelerate launch my_script.py --args_to_my_script
+```
+
+For instance, here is how you would run the GLUE example on the MRPC task (from the root of the repo):
+
+```bash
+accelerate launch examples/nlp_example.py
+```
+
+This CLI tool is **optional**, and you can still use `python my_script.py` or `python -m torchrun my_script.py` at your convenience.
+
+You can also directly pass in the arguments you would to `torchrun` as arguments to `accelerate launch` if you wish to not run` accelerate config`.
+
+For example, here is how to launch on two GPUs:
+
+```bash
+accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py
+```
+
+To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).
+
+Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/)
+
+## Launching multi-CPU run using MPI
+
+๐ค Here is another way to launch multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
+Once you have MPI setup on your cluster, just run:
+```bash
+accelerate config
+```
+Answer the questions that are asked, selecting to run using multi-CPU, and answer "yes" when asked if you want accelerate to launch mpirun.
+Then, use `accelerate launch` with your script like:
+```bash
+accelerate launch examples/nlp_example.py
+```
+Alternatively, you can use mpirun directly, without using the CLI like:
+```bash
+mpirun -np 2 python examples/nlp_example.py
+```
+
+## Launching training using DeepSpeed
+
+๐ค Accelerate supports training on single/multiple GPUs using DeepSpeed. To use it, you don't need to change anything in your training code; you can set everything using just `accelerate config`. However, if you desire to tweak your DeepSpeed related args from your Python script, we provide you the `DeepSpeedPlugin`.
+
+```python
+from accelerate import Accelerator, DeepSpeedPlugin
+
+# deepspeed needs to know your gradient accumulation steps beforehand, so don't forget to pass it
+# Remember you still need to do gradient accumulation by yourself, just like you would have done without deepspeed
+deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
+accelerator = Accelerator(mixed_precision='fp16', deepspeed_plugin=deepspeed_plugin)
+
+# How to save your ๐ค Transformer?
+accelerator.wait_for_everyone()
+unwrapped_model = accelerator.unwrap_model(model)
+unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
+```
+
+Note: DeepSpeed support is experimental for now. In case you get into some problem, please open an issue.
+
+## Launching your training from a notebook
+
+๐ค Accelerate also provides a `notebook_launcher` function you can use in a notebook to launch a distributed training. This is especially useful for Colab or Kaggle notebooks with a TPU backend. Just define your training loop in a `training_function` then in your last cell, add:
+
+```python
+from accelerate import notebook_launcher
+
+notebook_launcher(training_function)
+```
+
+An example can be found in [this notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb). [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb)
+
+## Why should I use ๐ค Accelerate?
+
+You should use ๐ค Accelerate when you want to easily run your training scripts in a distributed environment without having to renounce full control over your training loop. This is not a high-level framework above PyTorch, just a thin wrapper so you don't have to learn a new library. In fact, the whole API of ๐ค Accelerate is in one class, the `Accelerator` object.
+
+## Why shouldn't I use ๐ค Accelerate?
+
+You shouldn't use ๐ค Accelerate if you don't want to write a training loop yourself. There are plenty of high-level libraries above PyTorch that will offer you that, ๐ค Accelerate is not one of them.
+
+## Frameworks using ๐ค Accelerate
+
+If you like the simplicity of ๐ค Accelerate but would prefer a higher-level abstraction around its capabilities, some frameworks and libraries that are built on top of ๐ค Accelerate are listed below:
+
+* [Amphion](https://github.com/open-mmlab/Amphion) is a toolkit for Audio, Music, and Speech Generation. Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development.
+* [Animus](https://github.com/Scitator/animus) is a minimalistic framework to run machine learning experiments. Animus highlights common "breakpoints" in ML experiments and provides a unified interface for them within [IExperiment](https://github.com/Scitator/animus/blob/main/animus/core.py#L76).
+* [Catalyst](https://github.com/catalyst-team/catalyst#getting-started) is a PyTorch framework for Deep Learning Research and Development. It focuses on reproducibility, rapid experimentation, and codebase reuse so you can create something new rather than write yet another train loop. Catalyst provides a [Runner](https://catalyst-team.github.io/catalyst/api/core.html#runner) to connect all parts of the experiment: hardware backend, data transformations, model training, and inference logic.
+* [fastai](https://github.com/fastai/fastai#installing) is a PyTorch framework for Deep Learning that simplifies training fast and accurate neural nets using modern best practices. fastai provides a [Learner](https://docs.fast.ai/learner.html#Learner) to handle the training, fine-tuning, and inference of deep learning algorithms.
+* [Finetuner](https://github.com/jina-ai/finetuner) is a service that enables models to create higher-quality embeddings for semantic search, visual similarity search, cross-modal text<->image search, recommendation systems, clustering, duplication detection, anomaly detection, or other uses.
+* [InvokeAI](https://github.com/invoke-ai/InvokeAI) is a creative engine for Stable Diffusion models, offering industry-leading WebUI, terminal usage support, and serves as the foundation for many commercial products.
+* [Kornia](https://kornia.readthedocs.io/en/latest/get-started/introduction.html) is a differentiable library that allows classical computer vision to be integrated into deep learning models. Kornia provides a [Trainer](https://kornia.readthedocs.io/en/latest/x.html#kornia.x.Trainer) with the specific purpose to train and fine-tune the supported deep learning algorithms within the library.
+* [Open Assistant](https://projects.laion.ai/Open-Assistant/) is a chat-based assistant that understands tasks, can interact with their party systems, and retrieve information dynamically to do so.
+* [pytorch-accelerated](https://github.com/Chris-hughes10/pytorch-accelerated) is a lightweight training library, with a streamlined feature set centered around a general-purpose [Trainer](https://pytorch-accelerated.readthedocs.io/en/latest/trainer.html), that places a huge emphasis on simplicity and transparency; enabling users to understand exactly what is going on under the hood, but without having to write and maintain the boilerplate themselves!
+* [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is an open-source browser-based easy-to-use interface based on the Gradio library for Stable Diffusion.
+* [torchkeras](https://github.com/lyhue1991/torchkeras) is a simple tool for training pytorch model just in a keras style, a dynamic and beautiful plot is provided in notebook to monitor your loss or metric.
+* [transformers](https://github.com/huggingface/transformers) as a tool for helping train state-of-the-art machine learning models in PyTorch, Tensorflow, and JAX. (Accelerate is the backend for the PyTorch side).
+
+
+## Installation
+
+This repository is tested on Python 3.8+ and PyTorch 1.10.0+
+
+You should install ๐ค Accelerate in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
+
+First, create a virtual environment with the version of Python you're going to use and activate it.
+
+Then, you will need to install PyTorch: refer to the [official installation page](https://pytorch.org/get-started/locally/#start-locally) regarding the specific install command for your platform. Then ๐ค Accelerate can be installed using pip as follows:
+
+```bash
+pip install accelerate
+```
+
+## Supported integrations
+
+- CPU only
+- multi-CPU on one node (machine)
+- multi-CPU on several nodes (machines)
+- single GPU
+- multi-GPU on one node (machine)
+- multi-GPU on several nodes (machines)
+- TPU
+- FP16/BFloat16 mixed precision
+- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
+- DeepSpeed support (Experimental)
+- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
+- Megatron-LM support (Experimental)
+
+## Citing ๐ค Accelerate
+
+If you use ๐ค Accelerate in your publication, please cite it by using the following BibTeX entry.
+
+```bibtex
+@Misc{accelerate,
+ title = {Accelerate: Training and inference at scale made simple, efficient and adaptable.},
+ author = {Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar and Marc Sun and Benjamin Bossan},
+ howpublished = {\url{https://github.com/huggingface/accelerate}},
+ year = {2022}
+}
+```
diff --git a/agents/__init__.py b/agents/__init__.py
new file mode 100644
index 0000000..70762c2
--- /dev/null
+++ b/agents/__init__.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ..utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
+ "llm_engine": ["HfApiEngine", "TransformersEngine"],
+ "monitoring": ["stream_to_gradio"],
+ "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
+ _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
+ _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
+ _import_structure["search"] = ["DuckDuckGoSearchTool", "VisitWebpageTool"]
+ _import_structure["speech_to_text"] = ["SpeechToTextTool"]
+ _import_structure["text_to_speech"] = ["TextToSpeechTool"]
+ _import_structure["translation"] = ["TranslationTool"]
+
+if TYPE_CHECKING:
+ from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
+ from .llm_engine import HfApiEngine, TransformersEngine
+ from .monitoring import stream_to_gradio
+ from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .default_tools import FinalAnswerTool, PythonInterpreterTool
+ from .document_question_answering import DocumentQuestionAnsweringTool
+ from .image_question_answering import ImageQuestionAnsweringTool
+ from .search import DuckDuckGoSearchTool, VisitWebpageTool
+ from .speech_to_text import SpeechToTextTool
+ from .text_to_speech import TextToSpeechTool
+ from .translation import TranslationTool
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/agents/agent_types.py b/agents/agent_types.py
new file mode 100644
index 0000000..f5be746
--- /dev/null
+++ b/agents/agent_types.py
@@ -0,0 +1,260 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import pathlib
+import tempfile
+import uuid
+
+import numpy as np
+
+from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+if is_vision_available():
+ from PIL import Image
+ from PIL.Image import Image as ImageType
+else:
+ ImageType = object
+
+if is_torch_available():
+ import torch
+ from torch import Tensor
+else:
+ Tensor = object
+
+if is_soundfile_availble():
+ import soundfile as sf
+
+
+class AgentType:
+ """
+ Abstract class to be reimplemented to define types that can be returned by agents.
+
+ These objects serve three purposes:
+
+ - They behave as they were the type they're meant to be, e.g., a string for text, a PIL.Image for images
+ - They can be stringified: str(object) in order to return a string defining the object
+ - They should be displayed correctly in ipython notebooks/colab/jupyter
+ """
+
+ def __init__(self, value):
+ self._value = value
+
+ def __str__(self):
+ return self.to_string()
+
+ def to_raw(self):
+ logger.error(
+ "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
+ )
+ return self._value
+
+ def to_string(self) -> str:
+ logger.error(
+ "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
+ )
+ return str(self._value)
+
+
+class AgentText(AgentType, str):
+ """
+ Text type returned by the agent. Behaves as a string.
+ """
+
+ def to_raw(self):
+ return self._value
+
+ def to_string(self):
+ return str(self._value)
+
+
+class AgentImage(AgentType, ImageType):
+ """
+ Image type returned by the agent. Behaves as a PIL.Image.
+ """
+
+ def __init__(self, value):
+ AgentType.__init__(self, value)
+ ImageType.__init__(self)
+
+ if not is_vision_available():
+ raise ImportError("PIL must be installed in order to handle images.")
+
+ self._path = None
+ self._raw = None
+ self._tensor = None
+
+ if isinstance(value, ImageType):
+ self._raw = value
+ elif isinstance(value, (str, pathlib.Path)):
+ self._path = value
+ elif isinstance(value, torch.Tensor):
+ self._tensor = value
+ elif isinstance(value, np.ndarray):
+ self._tensor = torch.from_numpy(value)
+ else:
+ raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
+
+ def _ipython_display_(self, include=None, exclude=None):
+ """
+ Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
+ """
+ from IPython.display import Image, display
+
+ display(Image(self.to_string()))
+
+ def to_raw(self):
+ """
+ Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
+ """
+ if self._raw is not None:
+ return self._raw
+
+ if self._path is not None:
+ self._raw = Image.open(self._path)
+ return self._raw
+
+ if self._tensor is not None:
+ array = self._tensor.cpu().detach().numpy()
+ return Image.fromarray((255 - array * 255).astype(np.uint8))
+
+ def to_string(self):
+ """
+ Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
+ version of the image.
+ """
+ if self._path is not None:
+ return self._path
+
+ if self._raw is not None:
+ directory = tempfile.mkdtemp()
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
+ self._raw.save(self._path)
+ return self._path
+
+ if self._tensor is not None:
+ array = self._tensor.cpu().detach().numpy()
+
+ # There is likely simpler than load into image into save
+ img = Image.fromarray((255 - array * 255).astype(np.uint8))
+
+ directory = tempfile.mkdtemp()
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
+
+ img.save(self._path)
+
+ return self._path
+
+ def save(self, output_bytes, format, **params):
+ """
+ Saves the image to a file.
+ Args:
+ output_bytes (bytes): The output bytes to save the image to.
+ format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
+ **params: Additional parameters to pass to PIL.Image.save.
+ """
+ img = self.to_raw()
+ img.save(output_bytes, format, **params)
+
+
+class AgentAudio(AgentType, str):
+ """
+ Audio type returned by the agent.
+ """
+
+ def __init__(self, value, samplerate=16_000):
+ super().__init__(value)
+
+ if not is_soundfile_availble():
+ raise ImportError("soundfile must be installed in order to handle audio.")
+
+ self._path = None
+ self._tensor = None
+
+ self.samplerate = samplerate
+ if isinstance(value, (str, pathlib.Path)):
+ self._path = value
+ elif is_torch_available() and isinstance(value, torch.Tensor):
+ self._tensor = value
+ elif isinstance(value, tuple):
+ self.samplerate = value[0]
+ if isinstance(value[1], np.ndarray):
+ self._tensor = torch.from_numpy(value[1])
+ else:
+ self._tensor = torch.tensor(value[1])
+ else:
+ raise ValueError(f"Unsupported audio type: {type(value)}")
+
+ def _ipython_display_(self, include=None, exclude=None):
+ """
+ Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
+ """
+ from IPython.display import Audio, display
+
+ display(Audio(self.to_string(), rate=self.samplerate))
+
+ def to_raw(self):
+ """
+ Returns the "raw" version of that object. It is a `torch.Tensor` object.
+ """
+ if self._tensor is not None:
+ return self._tensor
+
+ if self._path is not None:
+ tensor, self.samplerate = sf.read(self._path)
+ self._tensor = torch.tensor(tensor)
+ return self._tensor
+
+ def to_string(self):
+ """
+ Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
+ version of the audio.
+ """
+ if self._path is not None:
+ return self._path
+
+ if self._tensor is not None:
+ directory = tempfile.mkdtemp()
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
+ sf.write(self._path, self._tensor, samplerate=self.samplerate)
+ return self._path
+
+
+AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
+INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
+
+if is_torch_available():
+ INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
+
+
+def handle_agent_inputs(*args, **kwargs):
+ args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
+ kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
+ return args, kwargs
+
+
+def handle_agent_outputs(output, output_type=None):
+ if output_type in AGENT_TYPE_MAPPING:
+ # If the class has defined outputs, we can map directly according to the class definition
+ decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
+ return decoded_outputs
+ else:
+ # If the class does not have defined output, then we map according to the type
+ for _k, _v in INSTANCE_TYPE_MAPPING.items():
+ if isinstance(output, _k):
+ return _v(output)
+ return output
diff --git a/agents/agents.py b/agents/agents.py
new file mode 100644
index 0000000..08c30d5
--- /dev/null
+++ b/agents/agents.py
@@ -0,0 +1,1278 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import re
+import time
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from .. import is_torch_available
+from ..utils import logging as transformers_logging
+from ..utils.import_utils import is_pygments_available
+from .agent_types import AgentAudio, AgentImage
+from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
+from .llm_engine import HfApiEngine, MessageRole
+from .monitoring import Monitor
+from .prompts import (
+ DEFAULT_CODE_SYSTEM_PROMPT,
+ DEFAULT_REACT_CODE_SYSTEM_PROMPT,
+ DEFAULT_REACT_JSON_SYSTEM_PROMPT,
+ PLAN_UPDATE_FINAL_PLAN_REDACTION,
+ PROMPTS_FOR_INITIAL_PLAN,
+ PROMPTS_FOR_PLAN_UPDATE,
+ SUPPORTED_PLAN_TYPES,
+ SYSTEM_PROMPT_FACTS,
+ SYSTEM_PROMPT_FACTS_UPDATE,
+ USER_PROMPT_FACTS_UPDATE,
+)
+from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
+from .tools import (
+ DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+ Tool,
+ get_tool_description_with_args,
+ load_tool,
+)
+
+
+if is_pygments_available():
+ from pygments import highlight
+ from pygments.formatters import Terminal256Formatter
+ from pygments.lexers import PythonLexer
+
+
+class CustomFormatter(logging.Formatter):
+ grey = "\x1b[38;20m"
+ bold_yellow = "\x1b[33;1m"
+ red = "\x1b[31;20m"
+ green = "\x1b[32;20m"
+ bold_green = "\x1b[32;20;1m"
+ bold_red = "\x1b[31;1m"
+ bold_white = "\x1b[37;1m"
+ orange = "\x1b[38;5;214m"
+ bold_orange = "\x1b[38;5;214;1m"
+ reset = "\x1b[0m"
+ format = "%(message)s"
+
+ FORMATS = {
+ logging.DEBUG: grey + format + reset,
+ logging.INFO: format,
+ logging.WARNING: bold_yellow + format + reset,
+ logging.ERROR: red + format + reset,
+ logging.CRITICAL: bold_red + format + reset,
+ 31: reset + format + reset,
+ 32: green + format + reset,
+ 33: bold_green + format + reset,
+ 34: bold_white + format + reset,
+ 35: orange + format + reset,
+ 36: bold_orange + format + reset,
+ }
+
+ def format(self, record):
+ log_fmt = self.FORMATS.get(record.levelno)
+ formatter = logging.Formatter(log_fmt)
+ return formatter.format(record)
+
+
+logger = transformers_logging.get_logger(__name__)
+logger.propagate = False
+ch = logging.StreamHandler()
+ch.setFormatter(CustomFormatter())
+logger.addHandler(ch)
+
+
+def parse_json_blob(json_blob: str) -> Dict[str, str]:
+ try:
+ first_accolade_index = json_blob.find("{")
+ last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
+ json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
+ json_data = json.loads(json_blob, strict=False)
+ return json_data
+ except json.JSONDecodeError as e:
+ place = e.pos
+ if json_blob[place - 1 : place + 2] == "},\n":
+ raise ValueError(
+ "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
+ )
+ raise ValueError(
+ f"The JSON blob you used is invalid due to the following error: {e}.\n"
+ f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
+ f"'{json_blob[place-4:place+5]}'."
+ )
+ except Exception as e:
+ raise ValueError(f"Error in parsing the JSON blob: {e}")
+
+
+def parse_code_blob(code_blob: str) -> str:
+ try:
+ pattern = r"```(?:py|python)?\n(.*?)\n```"
+ match = re.search(pattern, code_blob, re.DOTALL)
+ return match.group(1).strip()
+ except Exception as e:
+ raise ValueError(
+ f"""
+The code blob you used is invalid: due to the following error: {e}
+This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
+Thoughts: Your thoughts
+Code:
+```py
+# Your python code here
+```"""
+ )
+
+
+def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
+ json_blob = json_blob.replace("```json", "").replace("```", "")
+ tool_call = parse_json_blob(json_blob)
+ if "action" in tool_call and "action_input" in tool_call:
+ return tool_call["action"], tool_call["action_input"]
+ elif "action" in tool_call:
+ return tool_call["action"], None
+ else:
+ raise ValueError(
+ f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
+ )
+
+
+def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
+ """
+ Expects a text in the format: 'Action:', 'Action input:', 'Observation:'. 'Action input:' contains a json string with input arguments.
+ """
+ try:
+ if "Observation:" in text:
+ text = text.split("Observation:")[0]
+ if "Action:" in text:
+ text = text.split("Action:")[1]
+ tool_name, tool_input = text.split("Action input:")
+ if "{" in tool_input:
+ tool_input = parse_json_blob(tool_input)
+ else:
+ tool_input = tool_input.strip().replace('"', "")
+ return tool_name.strip().replace('"', "").replace("\\", ""), tool_input
+ except Exception as e:
+ raise ValueError(
+ f"Error in parsing the text tool call: {e}. Be sure to provide the correct format. DO NOT repeat your previous incorrect tool call."
+ )
+
+
+def to_text(input: Union[List[Dict[str, str]], Dict[str, str], str]) -> str:
+ if isinstance(input, list):
+ return "\n".join([m["content"] for m in input])
+ elif isinstance(input, dict):
+ return input["content"]
+ else:
+ return input
+
+
+HUGGINGFACE_DEFAULT_TOOLS = {}
+_tools_are_initialized = False
+
+
+class Toolbox:
+ """
+ The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
+ manage them.
+
+ Args:
+ tools (`List[Tool]`):
+ The list of tools to instantiate the toolbox with
+ add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
+ Whether to add the tools available within `transformers` to the toolbox.
+ """
+
+ def __init__(self, tools: List[Tool], add_base_tools: bool = False):
+ self._tools = {tool.name: tool for tool in tools}
+ if add_base_tools:
+ self.add_base_tools()
+ self._load_tools_if_needed()
+
+ def add_base_tools(self, add_python_interpreter: bool = False):
+ global _tools_are_initialized
+ global HUGGINGFACE_DEFAULT_TOOLS
+ if not _tools_are_initialized:
+ HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
+ _tools_are_initialized = True
+ for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
+ if tool.name != "python_interpreter" or add_python_interpreter:
+ self.add_tool(tool)
+ self._load_tools_if_needed()
+
+ @property
+ def tools(self) -> Dict[str, Tool]:
+ """Get all tools currently in the toolbox"""
+ return self._tools
+
+ def show_tool_descriptions(self, tool_description_template: str = None) -> str:
+ """
+ Returns the description of all tools in the toolbox
+
+ Args:
+ tool_description_template (`str`, *optional*):
+ The template to use to describe the tools. If not provided, the default template will be used.
+ """
+ return "\n".join(
+ [get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
+ )
+
+ def add_tool(self, tool: Tool):
+ """
+ Adds a tool to the toolbox
+
+ Args:
+ tool (`Tool`):
+ The tool to add to the toolbox.
+ """
+ if tool.name in self._tools:
+ raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
+ self._tools[tool.name] = tool
+
+ def remove_tool(self, tool_name: str):
+ """
+ Removes a tool from the toolbox
+
+ Args:
+ tool_name (`str`):
+ The tool to remove from the toolbox.
+ """
+ if tool_name not in self._tools:
+ raise KeyError(
+ f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
+ )
+ del self._tools[tool_name]
+
+ def update_tool(self, tool: Tool):
+ """
+ Updates a tool in the toolbox according to its name.
+
+ Args:
+ tool (`Tool`):
+ The tool to update to the toolbox.
+ """
+ if tool.name not in self._tools:
+ raise KeyError(
+ f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
+ )
+ self._tools[tool.name] = tool
+
+ def clear_toolbox(self):
+ """Clears the toolbox"""
+ self._tools = {}
+
+ def _load_tools_if_needed(self):
+ for name, tool in self._tools.items():
+ if not isinstance(tool, Tool):
+ task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
+ self._tools[name] = load_tool(task_or_repo_id)
+
+ def __repr__(self):
+ toolbox_description = "Toolbox contents:\n"
+ for tool in self._tools.values():
+ toolbox_description += f"\t{tool.name}: {tool.description}\n"
+ return toolbox_description
+
+
+class AgentError(Exception):
+ """Base class for other agent-related exceptions"""
+
+ def __init__(self, message):
+ super().__init__(message)
+ self.message = message
+
+
+class AgentParsingError(AgentError):
+ """Exception raised for errors in parsing in the agent"""
+
+ pass
+
+
+class AgentExecutionError(AgentError):
+ """Exception raised for errors in execution in the agent"""
+
+ pass
+
+
+class AgentMaxIterationsError(AgentError):
+ """Exception raised for errors in execution in the agent"""
+
+ pass
+
+
+class AgentGenerationError(AgentError):
+ """Exception raised for errors in generation in the agent"""
+
+ pass
+
+
+def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
+ tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
+ prompt = prompt_template.replace("<>", tool_descriptions)
+
+ if "<>" in prompt:
+ tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
+ prompt = prompt.replace("<>", ", ".join(tool_names))
+
+ return prompt
+
+
+def show_agents_descriptions(managed_agents: list):
+ managed_agents_descriptions = """
+You can also give requests to team members.
+Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
+Given that this team member is a real human, you should be very verbose in your request.
+Here is a list of the team members that you can call:"""
+ for agent in managed_agents.values():
+ managed_agents_descriptions += f"\n- {agent.name}: {agent.description}"
+ return managed_agents_descriptions
+
+
+def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
+ if managed_agents is not None:
+ return prompt_template.replace("<>", show_agents_descriptions(managed_agents))
+ else:
+ return prompt_template.replace("<>", "")
+
+
+def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
+ if "<>" not in prompt_template:
+ raise AgentError("Tag '<>' should be provided in the prompt.")
+ return prompt_template.replace("<>", str(authorized_imports))
+
+
+class Agent:
+ def __init__(
+ self,
+ tools: Union[List[Tool], Toolbox],
+ llm_engine: Callable = None,
+ system_prompt: Optional[str] = None,
+ tool_description_template: Optional[str] = None,
+ additional_args: Dict = {},
+ max_iterations: int = 6,
+ tool_parser: Optional[Callable] = None,
+ add_base_tools: bool = False,
+ verbose: int = 0,
+ grammar: Optional[Dict[str, str]] = None,
+ managed_agents: Optional[List] = None,
+ step_callbacks: Optional[List[Callable]] = None,
+ monitor_metrics: bool = True,
+ ):
+ if system_prompt is None:
+ system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
+ if tool_parser is None:
+ tool_parser = parse_json_tool_call
+ self.agent_name = self.__class__.__name__
+ self.llm_engine = llm_engine
+ self.system_prompt_template = system_prompt
+ self.tool_description_template = (
+ tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+ )
+ self.additional_args = additional_args
+ self.max_iterations = max_iterations
+ self.logger = logger
+ self.tool_parser = tool_parser
+ self.grammar = grammar
+
+ self.managed_agents = None
+ if managed_agents is not None:
+ self.managed_agents = {agent.name: agent for agent in managed_agents}
+
+ if isinstance(tools, Toolbox):
+ self._toolbox = tools
+ if add_base_tools:
+ if not is_torch_available():
+ raise ImportError("Using the base tools requires torch to be installed.")
+
+ self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == ReactJsonAgent))
+ else:
+ self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
+ self._toolbox.add_tool(FinalAnswerTool())
+
+ self.system_prompt = format_prompt_with_tools(
+ self._toolbox, self.system_prompt_template, self.tool_description_template
+ )
+ self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
+ self.prompt = None
+ self.logs = []
+ self.task = None
+
+ if verbose == 0:
+ logger.setLevel(logging.WARNING)
+ elif verbose == 1:
+ logger.setLevel(logging.INFO)
+ elif verbose == 2:
+ logger.setLevel(logging.DEBUG)
+
+ # Initialize step callbacks
+ self.step_callbacks = step_callbacks if step_callbacks is not None else []
+
+ # Initialize Monitor if monitor_metrics is True
+ self.monitor = None
+ if monitor_metrics:
+ self.monitor = Monitor(self.llm_engine)
+ self.step_callbacks.append(self.monitor.update_metrics)
+
+ @property
+ def toolbox(self) -> Toolbox:
+ """Get the toolbox currently available to the agent"""
+ return self._toolbox
+
+ def initialize_for_run(self):
+ self.token_count = 0
+ self.system_prompt = format_prompt_with_tools(
+ self._toolbox,
+ self.system_prompt_template,
+ self.tool_description_template,
+ )
+ self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
+ if hasattr(self, "authorized_imports"):
+ self.system_prompt = format_prompt_with_imports(
+ self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
+ )
+ self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
+ self.logger.log(33, "======== New task ========")
+ self.logger.log(34, self.task)
+ self.logger.debug("System prompt is as follows:")
+ self.logger.debug(self.system_prompt)
+
+ def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
+ """
+ Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
+ that can be used as input to the LLM.
+ """
+ prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0]["system_prompt"]}
+ task_message = {
+ "role": MessageRole.USER,
+ "content": "Task: " + self.logs[0]["task"],
+ }
+ if summary_mode:
+ memory = [task_message]
+ else:
+ memory = [prompt_message, task_message]
+ for i, step_log in enumerate(self.logs[1:]):
+ if "llm_output" in step_log and not summary_mode:
+ thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
+ memory.append(thought_message)
+ if "facts" in step_log:
+ thought_message = {
+ "role": MessageRole.ASSISTANT,
+ "content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
+ }
+ memory.append(thought_message)
+
+ if "plan" in step_log and not summary_mode:
+ thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
+ memory.append(thought_message)
+
+ if "tool_call" in step_log and summary_mode:
+ tool_call_message = {
+ "role": MessageRole.ASSISTANT,
+ "content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
+ }
+ memory.append(tool_call_message)
+
+ if "task" in step_log:
+ tool_call_message = {
+ "role": MessageRole.USER,
+ "content": "New task:\n" + step_log["task"],
+ }
+ memory.append(tool_call_message)
+
+ if "error" in step_log or "observation" in step_log:
+ if "error" in step_log:
+ message_content = (
+ f"[OUTPUT OF STEP {i}] -> Error:\n"
+ + str(step_log["error"])
+ + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
+ )
+ elif "observation" in step_log:
+ message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log['observation']}"
+ tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
+ memory.append(tool_response_message)
+
+ return memory
+
+ def get_succinct_logs(self):
+ return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
+
+ def extract_action(self, llm_output: str, split_token: str) -> str:
+ """
+ Parse action from the LLM output
+
+ Args:
+ llm_output (`str`): Output of the LLM
+ split_token (`str`): Separator for the action. Should match the example in the system prompt.
+ """
+ try:
+ split = llm_output.split(split_token)
+ rationale, action = (
+ split[-2],
+ split[-1],
+ ) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
+ except Exception as e:
+ self.logger.error(e, exc_info=1)
+ raise AgentParsingError(
+ f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
+ )
+ return rationale.strip(), action.strip()
+
+ def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
+ """
+ Execute tool with the provided input and returns the result.
+ This method replaces arguments with the actual values from the state if they refer to state variables.
+
+ Args:
+ tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
+ arguments (Dict[str, str]): Arguments passed to the Tool.
+ """
+ available_tools = self.toolbox.tools
+ if self.managed_agents is not None:
+ available_tools = {**available_tools, **self.managed_agents}
+ if tool_name not in available_tools:
+ error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
+ self.logger.error(error_msg, exc_info=1)
+ raise AgentExecutionError(error_msg)
+
+ try:
+ if isinstance(arguments, str):
+ observation = available_tools[tool_name](arguments)
+ elif isinstance(arguments, dict):
+ for key, value in arguments.items():
+ # if the value is the name of a state variable like "image.png", replace it with the actual value
+ if isinstance(value, str) and value in self.state:
+ arguments[key] = self.state[value]
+ observation = available_tools[tool_name](**arguments)
+ else:
+ raise AgentExecutionError(
+ f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
+ )
+ return observation
+ except Exception as e:
+ if tool_name in self.toolbox.tools:
+ raise AgentExecutionError(
+ f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
+ f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
+ )
+ elif tool_name in self.managed_agents:
+ raise AgentExecutionError(
+ f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
+ f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
+ )
+
+ def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
+ self.logger.warning("=== Agent thoughts:")
+ self.logger.log(31, rationale)
+ self.logger.warning(">>> Agent is executing the code below:")
+ if is_pygments_available():
+ self.logger.log(
+ 31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
+ )
+ else:
+ self.logger.log(31, code_action)
+ self.logger.warning("====")
+
+ def run(self, **kwargs):
+ """To be implemented in the child class"""
+ raise NotImplementedError
+
+
+class CodeAgent(Agent):
+ """
+ A class for an agent that solves the given task using a single block of code. It plans all its actions, then executes all in one shot.
+ """
+
+ def __init__(
+ self,
+ tools: List[Tool],
+ llm_engine: Optional[Callable] = None,
+ system_prompt: Optional[str] = None,
+ tool_description_template: Optional[str] = None,
+ grammar: Optional[Dict[str, str]] = None,
+ additional_authorized_imports: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ if llm_engine is None:
+ llm_engine = HfApiEngine()
+ if system_prompt is None:
+ system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
+ if tool_description_template is None:
+ tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+ super().__init__(
+ tools=tools,
+ llm_engine=llm_engine,
+ system_prompt=system_prompt,
+ tool_description_template=tool_description_template,
+ grammar=grammar,
+ **kwargs,
+ )
+
+ if not is_pygments_available():
+ transformers_logging.warning_once(
+ logger,
+ "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
+ "CodeAgent.",
+ )
+
+ self.python_evaluator = evaluate_python_code
+ self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
+ self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
+ self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports))
+
+ def parse_code_blob(self, result: str) -> str:
+ """
+ Override this method if you want to change the way the code is
+ cleaned in the `run` method.
+ """
+ return parse_code_blob(result)
+
+ def run(self, task: str, return_generated_code: bool = False, **kwargs):
+ """
+ Runs the agent for the given task.
+
+ Args:
+ task (`str`): The task to perform
+ return_generated_code (`bool`, *optional*, defaults to `False`): Whether to return the generated code instead of running it
+ kwargs (additional keyword arguments, *optional*):
+ Any keyword argument to send to the agent when evaluating the code.
+
+ Example:
+
+ ```py
+ from transformers.agents import CodeAgent
+
+ agent = CodeAgent(tools=[])
+ agent.run("What is the result of 2 power 3.7384?")
+ ```
+ """
+ self.task = task
+ if len(kwargs) > 0:
+ self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
+ self.state = kwargs.copy()
+ self.initialize_for_run()
+
+ # Run LLM
+ prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
+ task_message = {
+ "role": MessageRole.USER,
+ "content": "Task: " + self.task,
+ }
+
+ self.prompt = [prompt_message, task_message]
+ self.logger.info("====Executing with this prompt====")
+ self.logger.info(self.prompt)
+
+ additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+ llm_output = self.llm_engine(self.prompt, stop_sequences=[""], **additional_args)
+
+ if return_generated_code:
+ return llm_output
+
+ # Parse
+ try:
+ rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
+ except Exception as e:
+ self.logger.debug(
+ f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
+ )
+ rationale, code_action = "", llm_output
+
+ try:
+ code_action = self.parse_code_blob(code_action)
+ except Exception as e:
+ error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
+ self.logger.error(error_msg, exc_info=1)
+ return error_msg
+
+ # Execute
+ self.log_rationale_code_action(rationale, code_action)
+ try:
+ available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
+ output = self.python_evaluator(
+ code_action,
+ static_tools=available_tools,
+ custom_tools={},
+ state=self.state,
+ authorized_imports=self.authorized_imports,
+ )
+ self.logger.info(self.state["print_outputs"])
+ return output
+ except Exception as e:
+ error_msg = f"Error in execution: {e}. Be sure to provide correct code."
+ self.logger.error(error_msg, exc_info=1)
+ return error_msg
+
+
+class ReactAgent(Agent):
+ """
+ This agent that solves the given task step by step, using the ReAct framework:
+ While the objective is not reached, the agent will perform a cycle of thinking and acting.
+ The action will be parsed from the LLM output: it consists in calls to tools from the toolbox, with arguments chosen by the LLM engine.
+ """
+
+ def __init__(
+ self,
+ tools: List[Tool],
+ llm_engine: Optional[Callable] = None,
+ system_prompt: Optional[str] = None,
+ tool_description_template: Optional[str] = None,
+ grammar: Optional[Dict[str, str]] = None,
+ plan_type: Optional[str] = None,
+ planning_interval: Optional[int] = None,
+ **kwargs,
+ ):
+ if llm_engine is None:
+ llm_engine = HfApiEngine()
+ if system_prompt is None:
+ system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
+ if tool_description_template is None:
+ tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+ if plan_type is None:
+ plan_type = SUPPORTED_PLAN_TYPES[0]
+ else:
+ assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
+ super().__init__(
+ tools=tools,
+ llm_engine=llm_engine,
+ system_prompt=system_prompt,
+ tool_description_template=tool_description_template,
+ grammar=grammar,
+ **kwargs,
+ )
+ self.planning_interval = planning_interval
+ self.plan_type = plan_type
+
+ def provide_final_answer(self, task) -> str:
+ """
+ This method provides a final answer to the task, based on the logs of the agent's interactions.
+ """
+ self.prompt = [
+ {
+ "role": MessageRole.SYSTEM,
+ "content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
+ }
+ ]
+ self.prompt += self.write_inner_memory_from_logs()[1:]
+ self.prompt += [
+ {
+ "role": MessageRole.USER,
+ "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
+ }
+ ]
+ try:
+ return self.llm_engine(self.prompt)
+ except Exception as e:
+ return f"Error in generating final llm output: {e}."
+
+ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
+ """
+ Runs the agent for the given task.
+
+ Args:
+ task (`str`): The task to perform
+
+ Example:
+ ```py
+ from transformers.agents import ReactCodeAgent
+ agent = ReactCodeAgent(tools=[])
+ agent.run("What is the result of 2 power 3.7384?")
+ ```
+ """
+ self.task = task
+ if len(kwargs) > 0:
+ self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
+ self.state = kwargs.copy()
+ if reset:
+ self.initialize_for_run()
+ else:
+ self.logs.append({"task": task})
+ if stream:
+ return self.stream_run(task)
+ else:
+ return self.direct_run(task)
+
+ def stream_run(self, task: str):
+ """
+ Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
+ """
+ final_answer = None
+ iteration = 0
+ while final_answer is None and iteration < self.max_iterations:
+ step_start_time = time.time()
+ step_log_entry = {"iteration": iteration, "start_time": step_start_time}
+ try:
+ self.step(step_log_entry)
+ if "final_answer" in step_log_entry:
+ final_answer = step_log_entry["final_answer"]
+ except AgentError as e:
+ self.logger.error(e, exc_info=1)
+ step_log_entry["error"] = e
+ finally:
+ step_end_time = time.time()
+ step_log_entry["step_end_time"] = step_end_time
+ step_log_entry["step_duration"] = step_end_time - step_start_time
+ self.logs.append(step_log_entry)
+ for callback in self.step_callbacks:
+ callback(step_log_entry)
+ iteration += 1
+ yield step_log_entry
+
+ if final_answer is None and iteration == self.max_iterations:
+ error_message = "Reached max iterations."
+ final_step_log = {"error": AgentMaxIterationsError(error_message)}
+ self.logs.append(final_step_log)
+ self.logger.error(error_message, exc_info=1)
+ final_answer = self.provide_final_answer(task)
+ final_step_log["final_answer"] = final_answer
+ final_step_log["step_duration"] = 0
+ for callback in self.step_callbacks:
+ callback(final_step_log)
+ yield final_step_log
+
+ yield final_answer
+
+ def direct_run(self, task: str):
+ """
+ Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
+ """
+ final_answer = None
+ iteration = 0
+ while final_answer is None and iteration < self.max_iterations:
+ step_start_time = time.time()
+ step_log_entry = {"iteration": iteration, "start_time": step_start_time}
+ try:
+ if self.planning_interval is not None and iteration % self.planning_interval == 0:
+ self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
+ self.step(step_log_entry)
+ if "final_answer" in step_log_entry:
+ final_answer = step_log_entry["final_answer"]
+ except AgentError as e:
+ self.logger.error(e, exc_info=1)
+ step_log_entry["error"] = e
+ finally:
+ step_end_time = time.time()
+ step_log_entry["step_end_time"] = step_end_time
+ step_log_entry["step_duration"] = step_end_time - step_start_time
+ self.logs.append(step_log_entry)
+ for callback in self.step_callbacks:
+ callback(step_log_entry)
+ iteration += 1
+
+ if final_answer is None and iteration == self.max_iterations:
+ error_message = "Reached max iterations."
+ final_step_log = {"error": AgentMaxIterationsError(error_message)}
+ self.logs.append(final_step_log)
+ self.logger.error(error_message, exc_info=1)
+ final_answer = self.provide_final_answer(task)
+ final_step_log["final_answer"] = final_answer
+ final_step_log["step_duration"] = 0
+ for callback in self.step_callbacks:
+ callback(final_step_log)
+
+ return final_answer
+
+ def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
+ """
+ Used periodically by the agent to plan the next steps to reach the objective.
+
+ Args:
+ task (`str`): The task to perform
+ is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
+ iteration (`int`): The number of the current step, used as an indication for the LLM.
+ """
+ if is_first_step:
+ message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
+ message_prompt_task = {
+ "role": MessageRole.USER,
+ "content": f"""Here is the task:
+```
+{task}
+```
+Now begin!""",
+ }
+
+ answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
+
+ message_system_prompt_plan = {
+ "role": MessageRole.SYSTEM,
+ "content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["system"],
+ }
+ message_user_prompt_plan = {
+ "role": MessageRole.USER,
+ "content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format(
+ task=task,
+ tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
+ managed_agents_descriptions=(
+ show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
+ ),
+ answer_facts=answer_facts,
+ ),
+ }
+ answer_plan = self.llm_engine(
+ [message_system_prompt_plan, message_user_prompt_plan], stop_sequences=[""]
+ )
+
+ final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
+```
+{answer_plan}
+```"""
+ final_facts_redaction = f"""Here are the facts that I know so far:
+```
+{answer_facts}
+```""".strip()
+ self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
+ self.logger.log(36, "===== Initial plan =====")
+ self.logger.log(35, final_plan_redaction)
+ else: # update plan
+ agent_memory = self.write_inner_memory_from_logs(
+ summary_mode=False
+ ) # This will not log the plan but will log facts
+
+ # Redact updated facts
+ facts_update_system_prompt = {
+ "role": MessageRole.SYSTEM,
+ "content": SYSTEM_PROMPT_FACTS_UPDATE,
+ }
+ facts_update_message = {
+ "role": MessageRole.USER,
+ "content": USER_PROMPT_FACTS_UPDATE,
+ }
+ facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
+
+ # Redact updated plan
+ plan_update_message = {
+ "role": MessageRole.SYSTEM,
+ "content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["system"].format(task=task),
+ }
+ plan_update_message_user = {
+ "role": MessageRole.USER,
+ "content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format(
+ task=task,
+ tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
+ managed_agents_descriptions=(
+ show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
+ ),
+ facts_update=facts_update,
+ remaining_steps=(self.max_iterations - iteration),
+ ),
+ }
+ plan_update = self.llm_engine(
+ [plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=[""]
+ )
+
+ # Log final facts and plan
+ final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
+ final_facts_redaction = f"""Here is the updated list of the facts that I know:
+```
+{facts_update}
+```"""
+ self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
+ self.logger.log(36, "===== Updated plan =====")
+ self.logger.log(35, final_plan_redaction)
+
+
+class ReactJsonAgent(ReactAgent):
+ """
+ This agent that solves the given task step by step, using the ReAct framework:
+ While the objective is not reached, the agent will perform a cycle of thinking and acting.
+ The tool calls will be formulated by the LLM in JSON format, then parsed and executed.
+ """
+
+ def __init__(
+ self,
+ tools: List[Tool],
+ llm_engine: Optional[Callable] = None,
+ system_prompt: Optional[str] = None,
+ tool_description_template: Optional[str] = None,
+ grammar: Optional[Dict[str, str]] = None,
+ planning_interval: Optional[int] = None,
+ **kwargs,
+ ):
+ if llm_engine is None:
+ llm_engine = HfApiEngine()
+ if system_prompt is None:
+ system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT
+ if tool_description_template is None:
+ tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+ super().__init__(
+ tools=tools,
+ llm_engine=llm_engine,
+ system_prompt=system_prompt,
+ tool_description_template=tool_description_template,
+ grammar=grammar,
+ planning_interval=planning_interval,
+ **kwargs,
+ )
+
+ def step(self, log_entry: Dict[str, Any]):
+ """
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
+ The errors are raised here, they are caught and logged in the run() method.
+ """
+ agent_memory = self.write_inner_memory_from_logs()
+
+ self.prompt = agent_memory
+ self.logger.debug("===== New step =====")
+
+ # Add new step in logs
+ log_entry["agent_memory"] = agent_memory.copy()
+
+ self.logger.info("===== Calling LLM with this last message: =====")
+ self.logger.info(self.prompt[-1])
+
+ try:
+ additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+ llm_output = self.llm_engine(
+ self.prompt, stop_sequences=["", "Observation:"], **additional_args
+ )
+ except Exception as e:
+ raise AgentGenerationError(f"Error in generating llm output: {e}.")
+ self.logger.debug("===== Output message of the LLM: =====")
+ self.logger.debug(llm_output)
+ log_entry["llm_output"] = llm_output
+
+ # Parse
+ self.logger.debug("===== Extracting action =====")
+ rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
+
+ try:
+ tool_name, arguments = self.tool_parser(action)
+ except Exception as e:
+ raise AgentParsingError(f"Could not parse the given action: {e}.")
+
+ log_entry["rationale"] = rationale
+ log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
+
+ # Execute
+ self.logger.warning("=== Agent thoughts:")
+ self.logger.log(31, rationale)
+ self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
+ if tool_name == "final_answer":
+ if isinstance(arguments, dict):
+ if "answer" in arguments:
+ answer = arguments["answer"]
+ if (
+ isinstance(answer, str) and answer in self.state.keys()
+ ): # if the answer is a state variable, return the value
+ answer = self.state[answer]
+ else:
+ answer = arguments
+ else:
+ answer = arguments
+ log_entry["final_answer"] = answer
+ return answer
+ else:
+ if arguments is None:
+ arguments = {}
+ observation = self.execute_tool_call(tool_name, arguments)
+ observation_type = type(observation)
+ if observation_type in [AgentImage, AgentAudio]:
+ if observation_type == AgentImage:
+ observation_name = "image.png"
+ elif observation_type == AgentAudio:
+ observation_name = "audio.mp3"
+ # TODO: observation naming could allow for different names of same type
+
+ self.state[observation_name] = observation
+ updated_information = f"Stored '{observation_name}' in memory."
+ else:
+ updated_information = str(observation).strip()
+ self.logger.info(updated_information)
+ log_entry["observation"] = updated_information
+ return log_entry
+
+
+class ReactCodeAgent(ReactAgent):
+ """
+ This agent that solves the given task step by step, using the ReAct framework:
+ While the objective is not reached, the agent will perform a cycle of thinking and acting.
+ The tool calls will be formulated by the LLM in code format, then parsed and executed.
+ """
+
+ def __init__(
+ self,
+ tools: List[Tool],
+ llm_engine: Optional[Callable] = None,
+ system_prompt: Optional[str] = None,
+ tool_description_template: Optional[str] = None,
+ grammar: Optional[Dict[str, str]] = None,
+ additional_authorized_imports: Optional[List[str]] = None,
+ planning_interval: Optional[int] = None,
+ **kwargs,
+ ):
+ if llm_engine is None:
+ llm_engine = HfApiEngine()
+ if system_prompt is None:
+ system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
+ if tool_description_template is None:
+ tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+ super().__init__(
+ tools=tools,
+ llm_engine=llm_engine,
+ system_prompt=system_prompt,
+ tool_description_template=tool_description_template,
+ grammar=grammar,
+ planning_interval=planning_interval,
+ **kwargs,
+ )
+
+ if not is_pygments_available():
+ transformers_logging.warning_once(
+ logger,
+ "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
+ "ReactCodeAgent.",
+ )
+
+ self.python_evaluator = evaluate_python_code
+ self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
+ self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
+ self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports))
+ self.custom_tools = {}
+
+ def step(self, log_entry: Dict[str, Any]):
+ """
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
+ The errors are raised here, they are caught and logged in the run() method.
+ """
+ agent_memory = self.write_inner_memory_from_logs()
+
+ self.prompt = agent_memory.copy()
+ self.logger.debug("===== New step =====")
+
+ # Add new step in logs
+ log_entry["agent_memory"] = agent_memory.copy()
+
+ self.logger.info("===== Calling LLM with these last messages: =====")
+ self.logger.info(self.prompt[-2:])
+
+ try:
+ additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+ llm_output = self.llm_engine(
+ self.prompt, stop_sequences=["", "Observation:"], **additional_args
+ )
+ except Exception as e:
+ raise AgentGenerationError(f"Error in generating llm output: {e}.")
+
+ self.logger.debug("=== Output message of the LLM:")
+ self.logger.debug(llm_output)
+ log_entry["llm_output"] = llm_output
+
+ # Parse
+ self.logger.debug("=== Extracting action ===")
+ try:
+ rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
+ except Exception as e:
+ self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
+ rationale, raw_code_action = llm_output, llm_output
+
+ try:
+ code_action = parse_code_blob(raw_code_action)
+ except Exception as e:
+ error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
+ raise AgentParsingError(error_msg)
+
+ log_entry["rationale"] = rationale
+ log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
+
+ # Execute
+ self.log_rationale_code_action(rationale, code_action)
+ try:
+ static_tools = {
+ **BASE_PYTHON_TOOLS.copy(),
+ **self.toolbox.tools,
+ }
+ if self.managed_agents is not None:
+ static_tools = {**static_tools, **self.managed_agents}
+ result = self.python_evaluator(
+ code_action,
+ static_tools=static_tools,
+ custom_tools=self.custom_tools,
+ state=self.state,
+ authorized_imports=self.authorized_imports,
+ )
+ self.logger.warning("Print outputs:")
+ self.logger.log(32, self.state["print_outputs"])
+ observation = "Print outputs:\n" + self.state["print_outputs"]
+ if result is not None:
+ self.logger.warning("Last output from code snippet:")
+ self.logger.log(32, str(result))
+ observation += "Last output from code snippet:\n" + str(result)[:100000]
+ log_entry["observation"] = observation
+ except Exception as e:
+ error_msg = f"Code execution failed due to the following error:\n{str(e)}"
+ if "'dict' object has no attribute 'read'" in str(e):
+ error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
+ raise AgentExecutionError(error_msg)
+ for line in code_action.split("\n"):
+ if line[: len("final_answer")] == "final_answer":
+ self.logger.log(33, "Final answer:")
+ self.logger.log(32, result)
+ log_entry["final_answer"] = result
+ return result
+
+
+LENGTH_TRUNCATE_REPORTS = 1000
+
+
+class ManagedAgent:
+ def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
+ self.agent = agent
+ self.name = name
+ self.description = description
+ self.additional_prompting = additional_prompting
+ self.provide_run_summary = provide_run_summary
+
+ def write_full_task(self, task):
+ full_task = f"""You're a helpful agent named '{self.name}'.
+You have been submitted this task by your manager.
+---
+Task:
+{task}
+---
+You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer.
+
+Your final_answer WILL HAVE to contain these parts:
+### 1. Task outcome (short version):
+### 2. Task outcome (extremely detailed version):
+### 3. Additional context (if relevant):
+
+Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
+And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
+<>"""
+ if self.additional_prompting:
+ full_task = full_task.replace("\n<>", self.additional_prompting).strip()
+ else:
+ full_task = full_task.replace("\n<>", "").strip()
+ return full_task
+
+ def __call__(self, request, **kwargs):
+ full_task = self.write_full_task(request)
+ output = self.agent.run(full_task, **kwargs)
+ if self.provide_run_summary:
+ answer = f"Here is the final answer from your managed agent '{self.name}':\n"
+ answer += str(output)
+ answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
+ for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
+ content = message["content"]
+ if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
+ answer += "\n" + str(content) + "\n---"
+ else:
+ answer += (
+ "\n"
+ + str(content)[:LENGTH_TRUNCATE_REPORTS]
+ + "\n(...Step was truncated because too long)...\n---"
+ )
+ answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
+ return answer
+ else:
+ return output
diff --git a/agents/default_tools.py b/agents/default_tools.py
new file mode 100644
index 0000000..3946aa9
--- /dev/null
+++ b/agents/default_tools.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib.util
+import json
+import math
+from dataclasses import dataclass
+from math import sqrt
+from typing import Dict
+
+from huggingface_hub import hf_hub_download, list_spaces
+
+from ..utils import is_offline_mode
+from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
+from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
+
+
+def custom_print(*args):
+ return None
+
+
+BASE_PYTHON_TOOLS = {
+ "print": custom_print,
+ "isinstance": isinstance,
+ "range": range,
+ "float": float,
+ "int": int,
+ "bool": bool,
+ "str": str,
+ "set": set,
+ "list": list,
+ "dict": dict,
+ "tuple": tuple,
+ "round": round,
+ "ceil": math.ceil,
+ "floor": math.floor,
+ "log": math.log,
+ "exp": math.exp,
+ "sin": math.sin,
+ "cos": math.cos,
+ "tan": math.tan,
+ "asin": math.asin,
+ "acos": math.acos,
+ "atan": math.atan,
+ "atan2": math.atan2,
+ "degrees": math.degrees,
+ "radians": math.radians,
+ "pow": math.pow,
+ "sqrt": sqrt,
+ "len": len,
+ "sum": sum,
+ "max": max,
+ "min": min,
+ "abs": abs,
+ "enumerate": enumerate,
+ "zip": zip,
+ "reversed": reversed,
+ "sorted": sorted,
+ "all": all,
+ "any": any,
+ "map": map,
+ "filter": filter,
+ "ord": ord,
+ "chr": chr,
+ "next": next,
+ "iter": iter,
+ "divmod": divmod,
+ "callable": callable,
+ "getattr": getattr,
+ "hasattr": hasattr,
+ "setattr": setattr,
+ "issubclass": issubclass,
+ "type": type,
+}
+
+
+@dataclass
+class PreTool:
+ name: str
+ inputs: Dict[str, str]
+ output_type: type
+ task: str
+ description: str
+ repo_id: str
+
+
+HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
+ "image-transformation",
+ "text-to-image",
+]
+
+
+def get_remote_tools(logger, organization="huggingface-tools"):
+ if is_offline_mode():
+ logger.info("You are in offline mode, so remote tools are not available.")
+ return {}
+
+ spaces = list_spaces(author=organization)
+ tools = {}
+ for space_info in spaces:
+ repo_id = space_info.id
+ resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ config = json.load(reader)
+ task = repo_id.split("/")[-1]
+ tools[config["name"]] = PreTool(
+ task=task,
+ description=config["description"],
+ repo_id=repo_id,
+ name=task,
+ inputs=config["inputs"],
+ output_type=config["output_type"],
+ )
+
+ return tools
+
+
+def setup_default_tools(logger):
+ default_tools = {}
+ main_module = importlib.import_module("transformers")
+ tools_module = main_module.agents
+
+ for task_name, tool_class_name in TOOL_MAPPING.items():
+ tool_class = getattr(tools_module, tool_class_name)
+ tool_instance = tool_class()
+ default_tools[tool_class.name] = PreTool(
+ name=tool_instance.name,
+ inputs=tool_instance.inputs,
+ output_type=tool_instance.output_type,
+ task=task_name,
+ description=tool_instance.description,
+ repo_id=None,
+ )
+
+ return default_tools
+
+
+class PythonInterpreterTool(Tool):
+ name = "python_interpreter"
+ description = "This is a tool that evaluates python code. It can be used to perform calculations."
+
+ output_type = "string"
+
+ def __init__(self, *args, authorized_imports=None, **kwargs):
+ if authorized_imports is None:
+ self.authorized_imports = list(set(LIST_SAFE_MODULES))
+ else:
+ self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
+ self.inputs = {
+ "code": {
+ "type": "string",
+ "description": (
+ "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
+ f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
+ ),
+ }
+ }
+ super().__init__(*args, **kwargs)
+
+ def forward(self, code):
+ output = str(
+ evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
+ )
+ return output
+
+
+class FinalAnswerTool(Tool):
+ name = "final_answer"
+ description = "Provides a final answer to the given problem."
+ inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
+ output_type = "any"
+
+ def forward(self, answer):
+ return answer
diff --git a/agents/document_question_answering.py b/agents/document_question_answering.py
new file mode 100644
index 0000000..23ae5b0
--- /dev/null
+++ b/agents/document_question_answering.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+import numpy as np
+import torch
+
+from ..models.auto import AutoProcessor
+from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
+from ..utils import is_vision_available
+from .tools import PipelineTool
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+class DocumentQuestionAnsweringTool(PipelineTool):
+ default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
+ description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
+ name = "document_qa"
+ pre_processor_class = AutoProcessor
+ model_class = VisionEncoderDecoderModel
+
+ inputs = {
+ "document": {
+ "type": "image",
+ "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
+ },
+ "question": {"type": "string", "description": "The question in English"},
+ }
+ output_type = "string"
+
+ def __init__(self, *args, **kwargs):
+ if not is_vision_available():
+ raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
+
+ super().__init__(*args, **kwargs)
+
+ def encode(self, document: "Image", question: str):
+ task_prompt = "{user_input} "
+ prompt = task_prompt.replace("{user_input}", question)
+ decoder_input_ids = self.pre_processor.tokenizer(
+ prompt, add_special_tokens=False, return_tensors="pt"
+ ).input_ids
+ if isinstance(document, str):
+ img = Image.open(document).convert("RGB")
+ img_array = np.array(img).transpose(2, 0, 1)
+ document = torch.from_numpy(img_array)
+ pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
+
+ return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
+
+ def forward(self, inputs):
+ return self.model.generate(
+ inputs["pixel_values"].to(self.device),
+ decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
+ max_length=self.model.decoder.config.max_position_embeddings,
+ early_stopping=True,
+ pad_token_id=self.pre_processor.tokenizer.pad_token_id,
+ eos_token_id=self.pre_processor.tokenizer.eos_token_id,
+ use_cache=True,
+ num_beams=1,
+ bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
+ return_dict_in_generate=True,
+ ).sequences
+
+ def decode(self, outputs):
+ sequence = self.pre_processor.batch_decode(outputs)[0]
+ sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
+ sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
+ sequence = self.pre_processor.token2json(sequence)
+
+ return sequence["answer"]
diff --git a/agents/evaluate_agent.py b/agents/evaluate_agent.py
new file mode 100644
index 0000000..90dfd4f
--- /dev/null
+++ b/agents/evaluate_agent.py
@@ -0,0 +1,414 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .agents import BASE_PYTHON_TOOLS
+from .python_interpreter import InterpreterError, evaluate
+
+
+### Fake tools for test
+def classifier(text, labels):
+ return f"This is the classification of {text} along {labels}."
+
+
+def translator(text, src_lang, tgt_lang):
+ return f"This is the translation of {text} from {src_lang} to {tgt_lang}."
+
+
+def speaker(text):
+ return f"This is actually a sound reading {text}."
+
+
+def transcriber(audio):
+ if "sound" not in audio:
+ raise ValueError(f"`audio` ({audio}) is not a sound.")
+ return f"This is the transcribed text from {audio}."
+
+
+def image_generator(prompt):
+ return f"This is actually an image representing {prompt}."
+
+
+def image_captioner(image):
+ if "image" not in image:
+ raise ValueError(f"`image` ({image}) is not an image.")
+ return f"This is a description of {image}."
+
+
+def image_transformer(image, prompt):
+ if "image" not in image:
+ raise ValueError(f"`image` ({image}) is not an image.")
+ return f"This is a transformation of {image} according to {prompt}."
+
+
+def question_answerer(text, question):
+ return f"This is the answer to {question} from {text}."
+
+
+def image_qa(image, question):
+ if "image" not in image:
+ raise ValueError(f"`image` ({image}) is not an image.")
+ return f"This is the answer to {question} from {image}."
+
+
+def text_downloader(url):
+ return f"This is the content of {url}."
+
+
+def summarizer(text):
+ return f"This is a summary of {text}."
+
+
+def video_generator(prompt, seconds=2):
+ return f"A video of {prompt}"
+
+
+def document_qa(image, question):
+ return f"This is the answer to {question} from the document {image}."
+
+
+def image_segmenter(image, prompt):
+ return f"This is the mask of {prompt} in {image}"
+
+
+TEST_TOOLS = {
+ "text_classifier": classifier,
+ "translator": translator,
+ "text_reader": speaker,
+ "summarizer": summarizer,
+ "transcriber": transcriber,
+ "image_generator": image_generator,
+ "image_captioner": image_captioner,
+ "image_transformer": image_transformer,
+ "text_qa": question_answerer,
+ "text_downloader": text_downloader,
+ "image_qa": image_qa,
+ "video_generator": video_generator,
+ "document_qa": document_qa,
+ "image_segmenter": image_segmenter,
+}
+
+
+class Problem:
+ """
+ A class regrouping all the information to solve a problem on which we will evaluate agents.
+
+ Args:
+ task (`str` ou `list[str]`):
+ One or several descriptions of the task to perform. If a list, it should contain variations on the
+ phrasing, but for the same task.
+ inputs (`list[str]` or `dict[str, str]`):
+ The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
+ values. Pass along a dictionary when you want to specify the values of each inputs, or just the list of
+ inputs expected (the value used will be `<>` in this case).
+ answer (`str` or `list[str]`):
+ The theoretical answer (or list of possible valid answers) to the problem, as code.
+ """
+
+ def __init__(self, task, inputs, answer):
+ self.task = task
+ self.inputs = inputs
+ self.answer = answer
+
+
+### The list of problems the agent will be evaluated on.
+EVALUATION_TASKS = [
+ Problem(
+ task=[
+ "Is the following `text` (in Spanish) positive or negative?",
+ "Is the text in the variable `text` (in Spanish) positive or negative?",
+ "Translate the following `text` from Spanish to English then tell me if its positive or negative.",
+ ],
+ inputs=["text"],
+ answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
+ ),
+ Problem(
+ task=[
+ "Tell me out loud what the `image` contains.",
+ "Describe the following `image` out loud.",
+ "Find what is in the picture stored in `image` then read it out loud.",
+ ],
+ inputs=["image"],
+ answer=[
+ "text_reader(image_captioner(image))",
+ "text_reader(image_qa(image, question='What is in the image?'))",
+ ],
+ ),
+ Problem(
+ task=[
+ "Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
+ "Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
+ ],
+ inputs=["text_input", "prompt"],
+ answer="image_transformer(image_generator(text_input), prompt)",
+ ),
+ Problem(
+ task=[
+ "Download the content of `url`, summarize it then generate an image from its content.",
+ "Use a summary of the web page at `url` to generate an image.",
+ "Summarize the content of the web page at `url`, and use the result to generate an image.",
+ ],
+ inputs=["url"],
+ answer="image_generator(summarizer(text_downloader(url)))",
+ ),
+ Problem(
+ task=[
+ "Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
+ "Use the text prompt in `text` (in Spanish) to transform the following `image`.",
+ "Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
+ ],
+ inputs=["text", "image"],
+ answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
+ ),
+ Problem(
+ task=[
+ "Download the content of `url`, summarize it then read it out loud to me.",
+ "Read me a summary of the web page at `url`.",
+ ],
+ inputs=["url"],
+ answer="text_reader(summarizer(text_downloader(url)))",
+ ),
+ Problem(
+ task=[
+ "Generate an image from the text given in `text_input`.",
+ ],
+ inputs=["text_input"],
+ answer="image_generator(text_input)",
+ ),
+ Problem(
+ task=[
+ "Replace the beaver in the `image` by the `prompt`.",
+ "Transform the `image` so that it contains the `prompt`.",
+ "Use `prompt` to transform this `image`.",
+ ],
+ inputs=["image", "prompt"],
+ answer="image_transformer(image, prompt)",
+ ),
+ Problem(
+ task=[
+ "Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
+ "Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
+ "Read me a summary of the `text` out loud. Transcribe this and translate it in French.",
+ ],
+ inputs=["text"],
+ answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
+ ),
+ Problem(
+ task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
+ inputs={"prompt": "A lobster swimming"},
+ answer="video_generator('A lobster swimming')",
+ ),
+ Problem(
+ task=[
+ "Download the following file `url`, summarize it in a few words and generate a video from it."
+ "Fetch the file at this `url`, summarize it, and create an animation out of it."
+ ],
+ inputs=["url"],
+ answer="video_generator(summarizer(text_downloader(url)))",
+ ),
+]
+
+
+def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
+ if not isinstance(theoretical_answer, list):
+ return {name for name in TEST_TOOLS if name in code_answer}
+
+ if isinstance(agent_answer, dict):
+ for one_answer, one_code in zip(theoretical_answer, code_answer):
+ if one_answer in agent_answer.values():
+ return {name for name in TEST_TOOLS if name in one_code}
+
+ for one_answer, one_code in zip(theoretical_answer, code_answer):
+ if agent_answer == one_answer:
+ return {name for name in TEST_TOOLS if name in one_code}
+
+ return {name for name in TEST_TOOLS if name in code_answer[0]}
+
+
+def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
+ tools = BASE_PYTHON_TOOLS.copy()
+ for name, tool in TEST_TOOLS.items():
+ if name not in code:
+ continue
+ tools[name] = tool
+
+ if isinstance(inputs, dict):
+ inputs = inputs.copy()
+ elif inputs is not None:
+ inputs = {inp: f"<<{inp}>>" for inp in inputs}
+
+ if state is not None:
+ state.update(inputs)
+ else:
+ state = inputs
+
+ try:
+ return evaluate(code, tools, state)
+ except InterpreterError as e:
+ return str(e)
+ except Exception as e:
+ if verbose:
+ print(e)
+ return None
+
+
+def score_code(agent_answer, theoretical_answer, verbose: bool = False):
+ if verbose:
+ print(agent_answer, theoretical_answer)
+ theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]
+
+ if agent_answer in theoretical_answer:
+ if verbose:
+ print("Perfect!")
+ return 1
+ elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
+ if verbose:
+ print("Almsot perfect, result in state!")
+ return 0.75
+ else:
+ if verbose:
+ print("Result is not the right one but code executed.")
+ return 0.3
+
+
+def evaluate_one_result(code, agent_answer, theoretical_answer, answer, verbose=False):
+ tools_in_code = {name for name in TEST_TOOLS if f"`{name}`" in code}
+ theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
+ if tools_in_code == theoretical_tools:
+ tool_selection_score = 1.0
+ tool_selection_errors = None
+ else:
+ missing_tools = len(theoretical_tools - tools_in_code)
+ unexpected_tools = len(tools_in_code - theoretical_tools)
+ tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
+
+ tool_selection_errors = {
+ "selected_tools": tools_in_code,
+ "theoretical_tools": theoretical_tools,
+ }
+
+ tools_in_code = {name for name in TEST_TOOLS if name in code}
+ if tools_in_code == theoretical_tools:
+ tool_used_score = 1.0
+ tool_used_errors = None
+ else:
+ missing_tools = len(theoretical_tools - tools_in_code)
+ unexpected_tools = len(tools_in_code - theoretical_tools)
+ tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
+
+ tool_used_errors = {
+ "selected_tools": tools_in_code,
+ "theoretical_tools": theoretical_tools,
+ }
+
+ score = score_code(agent_answer, theoretical_answer, verbose=verbose)
+ if score < 1.0:
+ code_errors = {
+ "code_produced": code,
+ "evaluation": agent_answer,
+ "theoretical_answer": theoretical_answer,
+ }
+ else:
+ code_errors = None
+
+ return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
+
+
+def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
+ """
+ Evaluates a new agent on all `EVALUATION_TASKS`.
+
+ Example:
+
+ ```py
+ agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
+ bads = new_evaluate_agent(agent)
+ for bad in bads:
+ print(bad)
+ ```
+ """
+ # Sanity check
+ agent_tools = set(agent.toolbox.keys())
+ if agent_tools != set(TEST_TOOLS):
+ missing_tools = set(TEST_TOOLS) - agent_tools
+ unexpected_tools = set(agent_tools) - TEST_TOOLS
+ raise ValueError(
+ f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
+ )
+
+ eval_tasks = []
+ eval_idx = []
+ for idx, pb in enumerate(EVALUATION_TASKS):
+ if isinstance(pb.task, list):
+ eval_tasks.extend(pb.task)
+ eval_idx.extend([idx] * len(pb.task))
+ else:
+ eval_tasks.append(pb.task)
+ eval_idx.append(idx)
+
+ tool_selection_score = 0
+ tool_used_score = 0
+ code_score = 0
+
+ if return_errors:
+ tool_selection_errors = {}
+ tool_used_errors = {}
+ code_errors = {}
+
+ for start_idx in range(0, len(eval_tasks), batch_size):
+ end_idx = min(start_idx + batch_size, len(eval_tasks))
+ batch_tasks = eval_tasks[start_idx:end_idx]
+
+ results = [agent.run(task, return_generated_code=True) for task in batch_tasks]
+
+ for idx, result in enumerate(results):
+ problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
+ if verbose:
+ print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
+ code = agent.extract_action(result, split_token="Answer:")
+
+ # Evaluate agent answer and code answer
+ agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
+ if isinstance(problem.answer, list):
+ theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
+ else:
+ theoretical_answer = evaluate_code(problem.answer, problem.inputs)
+
+ scores, errors = evaluate_one_result(
+ code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
+ )
+
+ tool_selection_score += scores[0]
+ tool_used_score += scores[1]
+ code_score += scores[2]
+
+ if return_errors:
+ if errors[0] is not None:
+ tool_selection_errors[batch_tasks[idx]] = errors[0]
+ if errors[1] is not None:
+ tool_used_errors[batch_tasks[idx]] = errors[1]
+ if errors[2] is not None:
+ code_errors[batch_tasks[idx]] = errors[2]
+
+ scores = {
+ "tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
+ "tool used score": 100 * (tool_used_score / len(eval_tasks)),
+ "code score": 100 * (code_score / len(eval_tasks)),
+ }
+
+ if return_errors:
+ return scores, tool_selection_errors, tool_used_errors, code_errors
+ else:
+ return scores
diff --git a/agents/image_question_answering.py b/agents/image_question_answering.py
new file mode 100644
index 0000000..de0efb7
--- /dev/null
+++ b/agents/image_question_answering.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from PIL import Image
+
+from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
+from ..utils import requires_backends
+from .tools import PipelineTool
+
+
+class ImageQuestionAnsweringTool(PipelineTool):
+ default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
+ description = (
+ "This is a tool that answers a question about an image. It "
+ "returns a text that is the answer to the question."
+ )
+ name = "image_qa"
+ pre_processor_class = AutoProcessor
+ model_class = AutoModelForVisualQuestionAnswering
+
+ inputs = {
+ "image": {
+ "type": "image",
+ "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
+ },
+ "question": {"type": "string", "description": "The question in English"},
+ }
+ output_type = "string"
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+ super().__init__(*args, **kwargs)
+
+ def encode(self, image: "Image", question: str):
+ return self.pre_processor(image, question, return_tensors="pt")
+
+ def forward(self, inputs):
+ with torch.no_grad():
+ return self.model(**inputs).logits
+
+ def decode(self, outputs):
+ idx = outputs.argmax(-1).item()
+ return self.model.config.id2label[idx]
diff --git a/agents/llm_engine.py b/agents/llm_engine.py
new file mode 100644
index 0000000..afa4d62
--- /dev/null
+++ b/agents/llm_engine.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+from enum import Enum
+from typing import Dict, List, Optional
+
+from huggingface_hub import InferenceClient
+
+from .. import AutoTokenizer
+from ..pipelines.base import Pipeline
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MessageRole(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+ TOOL_CALL = "tool-call"
+ TOOL_RESPONSE = "tool-response"
+
+ @classmethod
+ def roles(cls):
+ return [r.value for r in cls]
+
+
+def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
+ """
+ Subsequent messages with the same role will be concatenated to a single message.
+
+ Args:
+ message_list (`List[Dict[str, str]]`): List of chat messages.
+ """
+ final_message_list = []
+ message_list = deepcopy(message_list) # Avoid modifying the original list
+ for message in message_list:
+ if not set(message.keys()) == {"role", "content"}:
+ raise ValueError("Message should contain only 'role' and 'content' keys!")
+
+ role = message["role"]
+ if role not in MessageRole.roles():
+ raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
+
+ if role in role_conversions:
+ message["role"] = role_conversions[role]
+
+ if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
+ final_message_list[-1]["content"] += "\n=======\n" + message["content"]
+ else:
+ final_message_list.append(message)
+ return final_message_list
+
+
+llama_role_conversions = {
+ MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}
+
+
+class HfEngine:
+ def __init__(self, model_id: Optional[str] = None):
+ self.last_input_token_count = None
+ self.last_output_token_count = None
+ if model_id is None:
+ model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+ logger.warning(f"Using default model for token counting: '{model_id}'")
+ try:
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+ except Exception as e:
+ logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
+ self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
+
+ def get_token_counts(self):
+ return {
+ "input_token_count": self.last_input_token_count,
+ "output_token_count": self.last_output_token_count,
+ }
+
+ def generate(
+ self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
+ ):
+ raise NotImplementedError
+
+ def __call__(
+ self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
+ ) -> str:
+ """Process the input messages and return the model's response.
+
+ This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization.
+
+ Parameters:
+ messages (`List[Dict[str, str]]`):
+ A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
+ stop_sequences (`List[str]`, *optional*):
+ A list of strings that will stop the generation if encountered in the model's output.
+ grammar (`str`, *optional*):
+ The grammar or formatting structure to use in the model's response.
+
+ Returns:
+ `str`: The text content of the model's response.
+
+ Example:
+ ```python
+ >>> engine = HfApiEngine(
+ ... model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ ... token="your_hf_token_here",
+ ... max_tokens=2000
+ ... )
+ >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+ >>> response = engine(messages, stop_sequences=["END"])
+ >>> print(response)
+ "Quantum mechanics is the branch of physics that studies..."
+ ```
+ """
+ if not isinstance(messages, List):
+ raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
+ if stop_sequences is None:
+ stop_sequences = []
+ response = self.generate(messages, stop_sequences, grammar)
+ self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
+ self.last_output_token_count = len(self.tokenizer.encode(response))
+
+ # Remove stop sequences from LLM output
+ for stop_seq in stop_sequences:
+ if response[-len(stop_seq) :] == stop_seq:
+ response = response[: -len(stop_seq)]
+ return response
+
+
+class HfApiEngine(HfEngine):
+ """A class to interact with Hugging Face's Inference API for language model interaction.
+
+ This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+
+ Parameters:
+ model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
+ The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+ token (`str`, *optional*):
+ Token used by the Hugging Face API for authentication.
+ If not provided, the class will use the token stored in the Hugging Face CLI configuration.
+ max_tokens (`int`, *optional*, defaults to 1500):
+ The maximum number of tokens allowed in the output.
+ timeout (`int`, *optional*, defaults to 120):
+ Timeout for the API request, in seconds.
+
+ Raises:
+ ValueError:
+ If the model name is not provided.
+ """
+
+ def __init__(
+ self,
+ model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
+ token: Optional[str] = None,
+ max_tokens: Optional[int] = 1500,
+ timeout: Optional[int] = 120,
+ ):
+ super().__init__(model_id=model)
+ self.model = model
+ self.client = InferenceClient(self.model, token=token, timeout=timeout)
+ self.max_tokens = max_tokens
+
+ def generate(
+ self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
+ ) -> str:
+ # Get clean message list
+ messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
+
+ # Send messages to the Hugging Face Inference API
+ if grammar is not None:
+ response = self.client.chat_completion(
+ messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
+ )
+ else:
+ response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
+
+ response = response.choices[0].message.content
+ return response
+
+
+class TransformersEngine(HfEngine):
+ """This engine uses a pre-initialized local text-generation pipeline."""
+
+ def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
+ super().__init__(model_id)
+ self.pipeline = pipeline
+
+ def generate(
+ self,
+ messages: List[Dict[str, str]],
+ stop_sequences: Optional[List[str]] = None,
+ grammar: Optional[str] = None,
+ max_length: int = 1500,
+ ) -> str:
+ # Get clean message list
+ messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
+
+ # Get LLM output
+ if stop_sequences is not None and len(stop_sequences) > 0:
+ stop_strings = stop_sequences
+ else:
+ stop_strings = None
+
+ output = self.pipeline(
+ messages,
+ stop_strings=stop_strings,
+ max_length=max_length,
+ tokenizer=self.pipeline.tokenizer,
+ )
+
+ response = output[0]["generated_text"][-1]["content"]
+ return response
+
+
+DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
+ "type": "regex",
+ "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n',
+}
+
+DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
+ "type": "regex",
+ "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```",
+}
diff --git a/agents/monitoring.py b/agents/monitoring.py
new file mode 100644
index 0000000..7126e72
--- /dev/null
+++ b/agents/monitoring.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ..utils import logging
+from .agent_types import AgentAudio, AgentImage, AgentText
+
+
+logger = logging.get_logger(__name__)
+
+
+def pull_message(step_log: dict, test_mode: bool = True):
+ try:
+ from gradio import ChatMessage
+ except ImportError:
+ if test_mode:
+
+ class ChatMessage:
+ def __init__(self, role, content, metadata=None):
+ self.role = role
+ self.content = content
+ self.metadata = metadata
+ else:
+ raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+
+ if step_log.get("rationale"):
+ yield ChatMessage(role="assistant", content=step_log["rationale"])
+ if step_log.get("tool_call"):
+ used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
+ content = step_log["tool_call"]["tool_arguments"]
+ if used_code:
+ content = f"```py\n{content}\n```"
+ yield ChatMessage(
+ role="assistant",
+ metadata={"title": f"๐ ๏ธ Used tool {step_log['tool_call']['tool_name']}"},
+ content=str(content),
+ )
+ if step_log.get("observation"):
+ yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
+ if step_log.get("error"):
+ yield ChatMessage(
+ role="assistant",
+ content=str(step_log["error"]),
+ metadata={"title": "๐ฅ Error"},
+ )
+
+
+def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
+ """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
+
+ try:
+ from gradio import ChatMessage
+ except ImportError:
+ if test_mode:
+
+ class ChatMessage:
+ def __init__(self, role, content, metadata=None):
+ self.role = role
+ self.content = content
+ self.metadata = metadata
+ else:
+ raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+
+ for step_log in agent.run(task, stream=True, **kwargs):
+ if isinstance(step_log, dict):
+ for message in pull_message(step_log, test_mode=test_mode):
+ yield message
+
+ final_answer = step_log # Last log is the run's final_answer
+
+ if isinstance(final_answer, AgentText):
+ yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
+ elif isinstance(final_answer, AgentImage):
+ yield ChatMessage(
+ role="assistant",
+ content={"path": final_answer.to_string(), "mime_type": "image/png"},
+ )
+ elif isinstance(final_answer, AgentAudio):
+ yield ChatMessage(
+ role="assistant",
+ content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
+ )
+ else:
+ yield ChatMessage(role="assistant", content=str(final_answer))
+
+
+class Monitor:
+ def __init__(self, tracked_llm_engine):
+ self.step_durations = []
+ self.tracked_llm_engine = tracked_llm_engine
+ if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found":
+ self.total_input_token_count = 0
+ self.total_output_token_count = 0
+
+ def update_metrics(self, step_log):
+ step_duration = step_log["step_duration"]
+ self.step_durations.append(step_duration)
+ logger.info(f"Step {len(self.step_durations)}:")
+ logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
+
+ if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
+ self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
+ self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
+ logger.info(f"- Input tokens: {self.total_input_token_count}")
+ logger.info(f"- Output tokens: {self.total_output_token_count}")
diff --git a/agents/prompts.py b/agents/prompts.py
new file mode 100644
index 0000000..7a84b1d
--- /dev/null
+++ b/agents/prompts.py
@@ -0,0 +1,789 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+from ..utils import cached_file
+
+
+# docstyle-ignore
+CHAT_MESSAGE_PROMPT = """
+Human: <>
+
+Assistant: """
+
+
+DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
+PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}
+
+
+def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
+ """
+ Downloads and caches the prompt from a repo and returns it contents (if necessary).
+ """
+ if prompt_or_repo_id is None:
+ prompt_or_repo_id = DEFAULT_PROMPTS_REPO
+
+ # prompt is considered a repo ID when it does not contain any kind of space
+ if re.search("\\s", prompt_or_repo_id) is not None:
+ return prompt_or_repo_id
+
+ prompt_file = cached_file(
+ prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
+ )
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ return f.read()
+
+
+DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
+To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
+You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
+Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
+In the end, use tool 'final_answer' to return your answer, its argument will be what gets returned.
+You can use imports in your code, but only from the following list of modules: <>
+Be sure to provide a 'Code:' token, else the run will fail.
+
+Tools:
+<>
+
+Examples:
+---
+Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
+
+Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
+Code:
+```py
+translated_question = translator(question=question, src_lang="French", tgt_lang="English")
+print(f"The translated question is {translated_question}.")
+answer = image_qa(image=image, question=translated_question)
+final_answer(f"The answer is {answer}")
+```
+
+---
+Task: "Identify the oldest person in the `document` and create an image showcasing the result."
+
+Thought: I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+Code:
+```py
+answer = document_qa(document, question="What is the oldest person?")
+print(f"The answer is {answer}.")
+image = image_generator(answer)
+final_answer(image)
+```
+
+---
+Task: "Generate an image using the text given in the variable `caption`."
+
+Thought: I will use the following tool: `image_generator` to generate an image.
+Code:
+```py
+image = image_generator(prompt=caption)
+final_answer(image)
+```
+
+---
+Task: "Summarize the text given in the variable `text` and read it out loud."
+
+Thought: I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
+Code:
+```py
+summarized_text = summarizer(text)
+print(f"Summary: {summarized_text}")
+audio_summary = text_reader(summarized_text)
+final_answer(audio_summary)
+```
+
+---
+Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
+
+Thought: I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
+Code:
+```py
+answer = text_qa(text=text, question=question)
+print(f"The answer is {answer}.")
+image = image_generator(answer)
+final_answer(image)
+```
+
+---
+Task: "Caption the following `image`."
+
+Thought: I will use the following tool: `image_captioner` to generate a caption for the image.
+Code:
+```py
+caption = image_captioner(image)
+final_answer(caption)
+```
+
+---
+Above example were using tools that might not exist for you. You only have acces to those Tools:
+<>
+
+Remember to make sure that variables you use are all defined.
+Be sure to provide a 'Code:\n```' sequence before the code and '```' after, else you will get an error.
+DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
+
+Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+"""
+
+
+DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can.
+To do so, you have been given access to the following tools: <>
+The way you use the tools is by specifying a json blob, ending with ''.
+Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
+
+The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
+{
+ "action": $TOOL_NAME,
+ "action_input": $INPUT
+}
+
+Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
+
+You should ALWAYS use the following format:
+
+Thought: you should always think about one action to take. Then use the action as follows:
+Action:
+$ACTION_JSON_BLOB
+Observation: the result of the action
+... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The $ACTION_JSON_BLOB must only use a SINGLE action at a time.)
+
+You can use the result of the previous action as input for the next action.
+The observation will always be a string: it can represent a file, like "image_1.jpg".
+Then you can use it as input for the next action. You can do it for instance as follows:
+
+Observation: "image_1.jpg"
+
+Thought: I need to transform the image that I received in the previous observation to make it green.
+Action:
+{
+ "action": "image_transformer",
+ "action_input": {"image": "image_1.jpg"}
+}
+
+To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
+Action:
+{
+ "action": "final_answer",
+ "action_input": {"answer": "insert your final answer here"}
+}
+
+
+Here are a few examples using notional tools:
+---
+Task: "Generate an image of the oldest person in this document."
+
+Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+Action:
+{
+ "action": "document_qa",
+ "action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
+}
+Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
+
+
+Thought: I will now generate an image showcasing the oldest person.
+Action:
+{
+ "action": "image_generator",
+ "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
+}
+Observation: "image.png"
+
+Thought: I will now return the generated image.
+Action:
+{
+ "action": "final_answer",
+ "action_input": "image.png"
+}
+
+---
+Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool
+Action:
+{
+ "action": "python_interpreter",
+ "action_input": {"code": "5 + 3 + 1294.678"}
+}
+Observation: 1302.678
+
+Thought: Now that I know the result, I will now return it.
+Action:
+{
+ "action": "final_answer",
+ "action_input": "1302.678"
+}
+
+---
+Task: "Which city has the highest population , Guangzhou or Shanghai?"
+
+Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
+Action:
+{
+ "action": "search",
+ "action_input": "Population Guangzhou"
+}
+Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+
+
+Thought: Now let's get the population of Shanghai using the tool 'search'.
+Action:
+{
+ "action": "search",
+ "action_input": "Population Shanghai"
+}
+Observation: '26 million (2019)'
+
+Thought: Now I know that Shanghai has a larger population. Let's return the result.
+Action:
+{
+ "action": "final_answer",
+ "action_input": "Shanghai"
+}
+
+
+Above example were using notional tools that might not exist for you. You only have acces to those tools:
+<>
+
+Here are the rules you should always follow to solve your task:
+1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with , else you will fail.
+2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead.
+3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
+4. Never re-do a tool call that you previously did with the exact same parameters.
+
+Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+"""
+
+
+DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
+To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
+To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
+
+At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
+Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '' sequence.
+During each intermediate step, you can use 'print()' to save whatever important information you will then need.
+These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step.
+In the end you have to return a final answer using the `final_answer` tool.
+
+Here are a few examples using notional tools:
+---
+Task: "Generate an image of the oldest person in this document."
+
+Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+Code:
+```py
+answer = document_qa(document=document, question="Who is the oldest person mentioned?")
+print(answer)
+```
+Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
+
+Thought: I will now generate an image showcasing the oldest person.
+Code:
+```py
+image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
+final_answer(image)
+```
+
+---
+Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
+Code:
+```py
+result = 5 + 3 + 1294.678
+final_answer(result)
+```
+
+---
+Task: "Which city has the highest population: Guangzhou or Shanghai?"
+
+Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
+Code:
+```py
+population_guangzhou = search("Guangzhou population")
+print("Population Guangzhou:", population_guangzhou)
+population_shanghai = search("Shanghai population")
+print("Population Shanghai:", population_shanghai)
+```
+Observation:
+Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+Population Shanghai: '26 million (2019)'
+
+Thought: Now I know that Shanghai has the highest population.
+Code:
+```py
+final_answer("Shanghai")
+```
+
+---
+Task: "What is the current age of the pope, raised to the power 0.36?"
+
+Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
+Code:
+```py
+pope_age = wiki(query="current pope age")
+print("Pope age:", pope_age)
+```
+Observation:
+Pope age: "The pope Francis is currently 85 years old."
+
+Thought: I know that the pope is 85 years old. Let's compute the result using python code.
+Code:
+```py
+pope_current_age = 85 ** 0.36
+final_answer(pope_current_age)
+```
+
+Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool):
+
+<>
+
+<>
+
+Here are the rules you should always follow to solve your task:
+1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```' sequence, else you will fail.
+2. Use only variables that you have defined!
+3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
+4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
+5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
+6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
+7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
+8. You can use imports in your code, but only from the following list of modules: <>
+9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
+10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
+
+Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+"""
+
+SYSTEM_PROMPT_FACTS = """Below I will present you a task.
+
+You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
+To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
+Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
+
+---
+### 1. Facts given in the task
+List here the specific facts given in the task that could help you (there might be nothing here).
+
+### 2. Facts to look up
+List here any facts that we may need to look up.
+Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
+
+### 3. Facts to derive
+List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
+
+Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
+### 1. Facts given in the task
+### 2. Facts to look up
+### 3. Facts to derive
+Do not add anything else."""
+
+SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
+
+Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
+This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
+After writing the final step of the plan, write the '\n' tag and stop there."""
+
+USER_PROMPT_PLAN = """
+Here is your task:
+
+Task:
+```
+{task}
+```
+
+Your plan can leverage any of these tools:
+{tool_descriptions}
+
+{managed_agents_descriptions}
+
+List of facts that you know:
+```
+{answer_facts}
+```
+
+Now begin! Write your plan below."""
+
+SYSTEM_PROMPT_FACTS_UPDATE = """
+You are a world expert at gathering known and unknown facts based on a conversation.
+Below you will find a task, and ahistory of attempts made to solve the task. You will have to produce a list of these:
+### 1. Facts given in the task
+### 2. Facts that we have learned
+### 3. Facts still to look up
+### 4. Facts still to derive
+Find the task and history below."""
+
+USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
+But since in your previous steps you may have learned useful new facts or invalidated some false ones.
+Please update your list of facts based on the previous history, and provide these headings:
+### 1. Facts given in the task
+### 2. Facts that we have learned
+### 3. Facts still to look up
+### 4. Facts still to derive
+
+Now write your new list of facts below."""
+
+SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
+
+You have been given a task:
+```
+{task}
+```
+
+Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
+If the previous tries so far have met some success, you can make an updated plan based on these actions.
+If you are stalled, you can make a completely new plan starting from scratch.
+"""
+
+USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
+```
+{task}
+```
+
+You have access to these tools and only these:
+{tool_descriptions}
+
+{managed_agents_descriptions}
+
+Here is the up to date list of facts that you know:
+```
+{facts_update}
+```
+
+Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
+This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
+Beware that you have {remaining_steps} steps remaining.
+Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
+After writing the final step of the plan, write the '\n' tag and stop there.
+
+Now write your new plan below."""
+
+SYSTEM_PROMPT_PLAN_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
+This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer. Each step should be structured as follows:
+Step #n: {
+ "description":
+ "tool": ,
+ "params": {
+
+ }
+ "output_var":
+}
+Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
+
+Below are some examples:
+
+Example 1:
+------
+Inputs:
+---
+Task:
+How many encoder blocks were in the first attention-only ML architecture published?
+
+[FACTS LIST]:
+### 1. Facts given in the task
+- The paper first introduced an attention-only ML architecture.
+- The specific information required is the page number where the number of encoder blocks is stated.
+- No local files are provided for access.
+
+### 2. Facts to look up
+- The title and authors of the paper that first introduced an attention-only ML architecture.
+ - Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
+- The full text of the identified paper.
+ - Source: Online academic repositories (e.g., arXiv, journal websites)
+- The specific page number in the paper where the number of encoder blocks is mentioned.
+ - Source: The content of the identified paper
+
+### 3. Facts to derive
+- By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
+ - Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
+```
+
+[STEP 1 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}
+[OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
+**Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
+[STEP 2 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}
+[OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
+---
+
+Output plan:
+---
+Step #1: {
+ "description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of "encoder blocks"",
+ "tool": "inspect_file_as_text",
+ "params": {
+ "file_path": "https://arxiv.org/pdf/1706.03762.pdf",
+ "question": "On which page is the number of encoder blocks mentioned?"
+ },
+ "output_var": "page_number"
+}
+
+Step #2: {
+ "description": "Provide the final answer",
+ "tool": "final_answer",
+ "params": {
+ "answer": "{page_number}"
+ },
+ "output_var": ""
+}
+------
+
+Example 2:
+------
+Inputs:
+---
+Task:
+How many golf balls fits into a Boeing-747?
+
+[FACTS LIST]:
+### 1. Facts given in the task
+- The task requires calculating the number of golf balls that fir into a Boeing-747
+### 2. Facts to look up
+- The volume of a golf ball
+- The volume of a Boeing-747
+### 3. Facts to derive
+- Once the volumes are known the final answer can be calculated
+---
+Output plan:
+---
+Step #1: {
+ "description": "Find the volume of a Boeing-747",
+ "tool": "web_search",
+ "params": {
+ "query": "What is the internal volume of a Boeing-747 in cubic meters?"
+ },
+ "output_var": "boeing_volume"
+}
+
+Step #2: {
+ "description": "Find the volume of a standard golf ball",
+ "tool": "ask_search_agent",
+ "params": {
+ "query": "What is the volume of a standard golf ball in cubic centimeters?"
+ },
+ "output_var": "golf_ball_volume"
+}
+
+Step #3: {
+ "description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
+ "tool": "python_code",
+ "params": {
+ "code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
+ },
+ "output_var": "number_of_golf_balls"
+}
+
+Step #4: {
+ "description": "Provide the final answer",
+ "tool": "final_answer",
+ "params": {
+ "answer": "{number_of_golf_balls}"
+ },
+ "output_var": ""
+}
+------
+Above example were using tools that might not exist for you.
+Your goal is to create a plan to solve the task."""
+
+USER_PROMPT_PLAN_STRUCTURED = """
+Here are your inputs:
+
+Task:
+```
+{task}
+```
+
+Your plan can leverage any of these tools:
+{tool_descriptions}
+These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
+
+List of facts that you know:
+```
+{answer_facts}
+```
+
+Now for the given task, create a plan taking into account the list of facts.
+After writing the final step of the plan, write the '\n' tag and stop there. Output the plan only and nothing else."""
+
+SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
+This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer. Each step should be structured as follows:
+Step #n: {{
+ "description":
+ "tool": ,
+ "params": {{
+
+ }}
+ "output_var":
+}}
+Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
+
+Below are some examples:
+
+Example 1:
+------
+Inputs:
+---
+Task:
+How many encoder blocks were in the first attention-only ML architecture published?
+
+[FACTS LIST]:
+### 1. Facts given in the task
+- The paper first introduced an attention-only ML architecture.
+- The specific information required is the page number where the number of encoder blocks is stated.
+- No local files are provided for access.
+
+### 2. Facts to look up
+- The title and authors of the paper that first introduced an attention-only ML architecture.
+ - Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
+- The full text of the identified paper.
+ - Source: Online academic repositories (e.g., arXiv, journal websites)
+- The specific page number in the paper where the number of encoder blocks is mentioned.
+ - Source: The content of the identified paper
+
+### 3. Facts to derive
+- By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
+ - Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
+```
+
+[STEP 1 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}}
+[OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
+**Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
+[STEP 2 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}}
+[OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
+---
+
+Output plan:
+---
+Step #1: {{
+ "description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of "encoder blocks"",
+ "tool": "inspect_file_as_text",
+ "params": {{
+ "file_path": "https://arxiv.org/pdf/1706.03762.pdf",
+ "question": "On which page is the number of encoder blocks mentioned?"
+ }},
+ "output_var": "page_number"
+}}
+
+Step #2: {{
+ "description": "Provide the final answer",
+ "tool": "final_answer",
+ "params": {{
+ "answer": "{{page_number}}"
+ }},
+ "output_var": ""
+}}
+------
+
+Example 2:
+------
+Inputs:
+---
+Task:
+How many golf balls fits into a Boeing-747?
+
+[FACTS LIST]:
+### 1. Facts given in the task
+- The task requires calculating the number of golf balls that fir into a Boeing-747
+### 2. Facts to look up
+- The volume of a golf ball
+- The volume of a Boeing-747
+### 3. Facts to derive
+- Once the volumes are known the final answer can be calculated
+---
+Output plan:
+---
+Step #1: {{
+ "description": "Find the volume of a Boeing-747",
+ "tool": "web_search",
+ "params": {{
+ "query": "What is the internal volume of a Boeing-747 in cubic meters?"
+ }},
+ "output_var": "boeing_volume"
+}}
+
+Step #2: {{
+ "description": "Find the volume of a standard golf ball",
+ "tool": "ask_search_agent",
+ "params": {{
+ "query": "What is the volume of a standard golf ball in cubic centimeters?"
+ }},
+ "output_var": "golf_ball_volume"
+}}
+
+Step #3: {{
+ "description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
+ "tool": "python_code",
+ "params": {{
+ "code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
+ }},
+ "output_var": "number_of_golf_balls"
+}}
+
+Step #4: {{
+ "description": "Provide the final answer",
+ "tool": "final_answer",
+ "params": {{
+ "answer": "{{number_of_golf_balls}}"
+ }},
+ "output_var": ""
+}}
+------
+Above example were using tools that might not exist for you.
+Find below the record of what has been tried so far to solve it. Your goal is to create an updated plan to solve the task."""
+
+USER_PROMPT_PLAN_UPDATE_STRUCTURED = """
+Here are your inputs:
+
+Task:
+```
+{task}
+```
+
+Your plan can leverage any of these tools:
+{tool_descriptions}
+These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
+
+List of facts that you know:
+```
+{facts_update}
+```
+
+Now for the given task, create a plan taking into account the above inputs and list of facts.
+Beware that you have {remaining_steps} steps remaining.
+After writing the final step of the plan, write the '\n' tag and stop there. Output the plan only and nothing else."""
+
+PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
+```
+{task}
+```
+
+Here is my new/updated plan of action to solve the task:
+```
+{plan_update}
+```"""
+
+SUPPORTED_PLAN_TYPES = ["default", "structured"]
+
+PROMPTS_FOR_INITIAL_PLAN = {
+ "default": {"system": SYSTEM_PROMPT_PLAN, "user": USER_PROMPT_PLAN},
+ "structured": {"system": SYSTEM_PROMPT_PLAN_STRUCTURED, "user": USER_PROMPT_PLAN_STRUCTURED},
+}
+
+PROMPTS_FOR_PLAN_UPDATE = {
+ "default": {"system": SYSTEM_PROMPT_PLAN_UPDATE, "user": USER_PROMPT_PLAN_UPDATE},
+ "structured": {"system": SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED, "user": USER_PROMPT_PLAN_UPDATE_STRUCTURED},
+}
diff --git a/agents/python_interpreter.py b/agents/python_interpreter.py
new file mode 100644
index 0000000..6e90f35
--- /dev/null
+++ b/agents/python_interpreter.py
@@ -0,0 +1,908 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import ast
+import builtins
+import difflib
+from collections.abc import Mapping
+from importlib import import_module
+from typing import Any, Callable, Dict, List, Optional
+
+import numpy as np
+
+from ..utils import is_pandas_available
+
+
+if is_pandas_available():
+ import pandas as pd
+
+
+class InterpreterError(ValueError):
+ """
+ An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
+ operations.
+ """
+
+ pass
+
+
+ERRORS = {
+ name: getattr(builtins, name)
+ for name in dir(builtins)
+ if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
+}
+
+
+LIST_SAFE_MODULES = [
+ "random",
+ "collections",
+ "math",
+ "time",
+ "queue",
+ "itertools",
+ "re",
+ "stat",
+ "statistics",
+ "unicodedata",
+]
+
+PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
+OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
+
+
+class BreakException(Exception):
+ pass
+
+
+class ContinueException(Exception):
+ pass
+
+
+class ReturnException(Exception):
+ def __init__(self, value):
+ self.value = value
+
+
+def get_iterable(obj):
+ if isinstance(obj, list):
+ return obj
+ elif hasattr(obj, "__iter__"):
+ return list(obj)
+ else:
+ raise InterpreterError("Object is not iterable")
+
+
+def evaluate_unaryop(expression, state, static_tools, custom_tools):
+ operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
+ if isinstance(expression.op, ast.USub):
+ return -operand
+ elif isinstance(expression.op, ast.UAdd):
+ return operand
+ elif isinstance(expression.op, ast.Not):
+ return not operand
+ elif isinstance(expression.op, ast.Invert):
+ return ~operand
+ else:
+ raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
+
+
+def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
+ args = [arg.arg for arg in lambda_expression.args.args]
+
+ def lambda_func(*values):
+ new_state = state.copy()
+ for arg, value in zip(args, values):
+ new_state[arg] = value
+ return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
+
+ return lambda_func
+
+
+def evaluate_while(while_loop, state, static_tools, custom_tools):
+ max_iterations = 1000
+ iterations = 0
+ while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
+ for node in while_loop.body:
+ try:
+ evaluate_ast(node, state, static_tools, custom_tools)
+ except BreakException:
+ return None
+ except ContinueException:
+ break
+ iterations += 1
+ if iterations > max_iterations:
+ raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
+ return None
+
+
+def create_function(func_def, state, static_tools, custom_tools):
+ def new_func(*args, **kwargs):
+ func_state = state.copy()
+ arg_names = [arg.arg for arg in func_def.args.args]
+ default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
+
+ # Apply default values
+ defaults = dict(zip(arg_names[-len(default_values) :], default_values))
+
+ # Set positional arguments
+ for name, value in zip(arg_names, args):
+ func_state[name] = value
+
+ # # Set keyword arguments
+ for name, value in kwargs.items():
+ func_state[name] = value
+
+ # Handle variable arguments
+ if func_def.args.vararg:
+ vararg_name = func_def.args.vararg.arg
+ func_state[vararg_name] = args
+
+ if func_def.args.kwarg:
+ kwarg_name = func_def.args.kwarg.arg
+ func_state[kwarg_name] = kwargs
+
+ # Set default values for arguments that were not provided
+ for name, value in defaults.items():
+ if name not in func_state:
+ func_state[name] = value
+
+ # Update function state with self and __class__
+ if func_def.args.args and func_def.args.args[0].arg == "self":
+ if args:
+ func_state["self"] = args[0]
+ func_state["__class__"] = args[0].__class__
+
+ result = None
+ try:
+ for stmt in func_def.body:
+ result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
+ except ReturnException as e:
+ result = e.value
+ return result
+
+ return new_func
+
+
+def create_class(class_name, class_bases, class_body):
+ class_dict = {}
+ for key, value in class_body.items():
+ class_dict[key] = value
+ return type(class_name, tuple(class_bases), class_dict)
+
+
+def evaluate_function_def(func_def, state, static_tools, custom_tools):
+ custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
+ return custom_tools[func_def.name]
+
+
+def evaluate_class_def(class_def, state, static_tools, custom_tools):
+ class_name = class_def.name
+ bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
+ class_dict = {}
+
+ for stmt in class_def.body:
+ if isinstance(stmt, ast.FunctionDef):
+ class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
+ elif isinstance(stmt, ast.Assign):
+ for target in stmt.targets:
+ if isinstance(target, ast.Name):
+ class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
+ elif isinstance(target, ast.Attribute):
+ class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
+ else:
+ raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
+
+ new_class = type(class_name, tuple(bases), class_dict)
+ state[class_name] = new_class
+ return new_class
+
+
+def evaluate_augassign(expression, state, static_tools, custom_tools):
+ # Helper function to get current value and set new value based on the target type
+ def get_current_value(target):
+ if isinstance(target, ast.Name):
+ return state.get(target.id, 0)
+ elif isinstance(target, ast.Subscript):
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+ key = evaluate_ast(target.slice, state, static_tools, custom_tools)
+ return obj[key]
+ elif isinstance(target, ast.Attribute):
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+ return getattr(obj, target.attr)
+ elif isinstance(target, ast.Tuple):
+ return tuple(get_current_value(elt) for elt in target.elts)
+ elif isinstance(target, ast.List):
+ return [get_current_value(elt) for elt in target.elts]
+ else:
+ raise InterpreterError("AugAssign not supported for {type(target)} targets.")
+
+ current_value = get_current_value(expression.target)
+ value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
+
+ # Determine the operation and apply it
+ if isinstance(expression.op, ast.Add):
+ if isinstance(current_value, list):
+ if not isinstance(value_to_add, list):
+ raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
+ updated_value = current_value + value_to_add
+ else:
+ updated_value = current_value + value_to_add
+ elif isinstance(expression.op, ast.Sub):
+ updated_value = current_value - value_to_add
+ elif isinstance(expression.op, ast.Mult):
+ updated_value = current_value * value_to_add
+ elif isinstance(expression.op, ast.Div):
+ updated_value = current_value / value_to_add
+ elif isinstance(expression.op, ast.Mod):
+ updated_value = current_value % value_to_add
+ elif isinstance(expression.op, ast.Pow):
+ updated_value = current_value**value_to_add
+ elif isinstance(expression.op, ast.FloorDiv):
+ updated_value = current_value // value_to_add
+ elif isinstance(expression.op, ast.BitAnd):
+ updated_value = current_value & value_to_add
+ elif isinstance(expression.op, ast.BitOr):
+ updated_value = current_value | value_to_add
+ elif isinstance(expression.op, ast.BitXor):
+ updated_value = current_value ^ value_to_add
+ elif isinstance(expression.op, ast.LShift):
+ updated_value = current_value << value_to_add
+ elif isinstance(expression.op, ast.RShift):
+ updated_value = current_value >> value_to_add
+ else:
+ raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
+
+ # Update the state
+ set_value(expression.target, updated_value, state, static_tools, custom_tools)
+
+ return updated_value
+
+
+def evaluate_boolop(node, state, static_tools, custom_tools):
+ if isinstance(node.op, ast.And):
+ for value in node.values:
+ if not evaluate_ast(value, state, static_tools, custom_tools):
+ return False
+ return True
+ elif isinstance(node.op, ast.Or):
+ for value in node.values:
+ if evaluate_ast(value, state, static_tools, custom_tools):
+ return True
+ return False
+
+
+def evaluate_binop(binop, state, static_tools, custom_tools):
+ # Recursively evaluate the left and right operands
+ left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
+ right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
+
+ # Determine the operation based on the type of the operator in the BinOp
+ if isinstance(binop.op, ast.Add):
+ return left_val + right_val
+ elif isinstance(binop.op, ast.Sub):
+ return left_val - right_val
+ elif isinstance(binop.op, ast.Mult):
+ return left_val * right_val
+ elif isinstance(binop.op, ast.Div):
+ return left_val / right_val
+ elif isinstance(binop.op, ast.Mod):
+ return left_val % right_val
+ elif isinstance(binop.op, ast.Pow):
+ return left_val**right_val
+ elif isinstance(binop.op, ast.FloorDiv):
+ return left_val // right_val
+ elif isinstance(binop.op, ast.BitAnd):
+ return left_val & right_val
+ elif isinstance(binop.op, ast.BitOr):
+ return left_val | right_val
+ elif isinstance(binop.op, ast.BitXor):
+ return left_val ^ right_val
+ elif isinstance(binop.op, ast.LShift):
+ return left_val << right_val
+ elif isinstance(binop.op, ast.RShift):
+ return left_val >> right_val
+ else:
+ raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
+
+
+def evaluate_assign(assign, state, static_tools, custom_tools):
+ result = evaluate_ast(assign.value, state, static_tools, custom_tools)
+ if len(assign.targets) == 1:
+ target = assign.targets[0]
+ set_value(target, result, state, static_tools, custom_tools)
+ else:
+ if len(assign.targets) != len(result):
+ raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
+ expanded_values = []
+ for tgt in assign.targets:
+ if isinstance(tgt, ast.Starred):
+ expanded_values.extend(result)
+ else:
+ expanded_values.append(result)
+ for tgt, val in zip(assign.targets, expanded_values):
+ set_value(tgt, val, state, static_tools, custom_tools)
+ return result
+
+
+def set_value(target, value, state, static_tools, custom_tools):
+ if isinstance(target, ast.Name):
+ if target.id in static_tools:
+ raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
+ state[target.id] = value
+ elif isinstance(target, ast.Tuple):
+ if not isinstance(value, tuple):
+ if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
+ value = tuple(value)
+ else:
+ raise InterpreterError("Cannot unpack non-tuple value")
+ if len(target.elts) != len(value):
+ raise InterpreterError("Cannot unpack tuple of wrong size")
+ for i, elem in enumerate(target.elts):
+ set_value(elem, value[i], state, static_tools, custom_tools)
+ elif isinstance(target, ast.Subscript):
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+ key = evaluate_ast(target.slice, state, static_tools, custom_tools)
+ obj[key] = value
+ elif isinstance(target, ast.Attribute):
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+ setattr(obj, target.attr, value)
+
+
+def evaluate_call(call, state, static_tools, custom_tools):
+ if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
+ raise InterpreterError(f"This is not a correct function: {call.func}).")
+ if isinstance(call.func, ast.Attribute):
+ obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
+ func_name = call.func.attr
+ if not hasattr(obj, func_name):
+ raise InterpreterError(f"Object {obj} has no attribute {func_name}")
+ func = getattr(obj, func_name)
+
+ elif isinstance(call.func, ast.Name):
+ func_name = call.func.id
+ if func_name in state:
+ func = state[func_name]
+ elif func_name in static_tools:
+ func = static_tools[func_name]
+ elif func_name in custom_tools:
+ func = custom_tools[func_name]
+ elif func_name in ERRORS:
+ func = ERRORS[func_name]
+ else:
+ raise InterpreterError(
+ f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
+ )
+
+ args = []
+ for arg in call.args:
+ if isinstance(arg, ast.Starred):
+ args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
+ else:
+ args.append(evaluate_ast(arg, state, static_tools, custom_tools))
+
+ args = []
+ for arg in call.args:
+ if isinstance(arg, ast.Starred):
+ unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
+ if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
+ raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
+ args.extend(unpacked)
+ else:
+ args.append(evaluate_ast(arg, state, static_tools, custom_tools))
+
+ kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
+
+ if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
+ # Instantiate the class using its constructor
+ obj = func.__new__(func) # Create a new instance of the class
+ if hasattr(obj, "__init__"): # Check if the class has an __init__ method
+ obj.__init__(*args, **kwargs) # Call the __init__ method correctly
+ return obj
+ else:
+ if func_name == "super":
+ if not args:
+ if "__class__" in state and "self" in state:
+ return super(state["__class__"], state["self"])
+ else:
+ raise InterpreterError("super() needs at least one argument")
+ cls = args[0]
+ if not isinstance(cls, type):
+ raise InterpreterError("super() argument 1 must be type")
+ if len(args) == 1:
+ return super(cls)
+ elif len(args) == 2:
+ instance = args[1]
+ return super(cls, instance)
+ else:
+ raise InterpreterError("super() takes at most 2 arguments")
+ else:
+ if func_name == "print":
+ output = " ".join(map(str, args))
+ global PRINT_OUTPUTS
+ PRINT_OUTPUTS += output + "\n"
+ # cap the number of lines
+ return None
+ else: # Assume it's a callable object
+ output = func(*args, **kwargs)
+ return output
+
+
+def evaluate_subscript(subscript, state, static_tools, custom_tools):
+ index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
+ value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
+
+ if isinstance(value, str) and isinstance(index, str):
+ raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
+ if isinstance(value, pd.core.indexing._LocIndexer):
+ parent_object = value.obj
+ return parent_object.loc[index]
+ if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
+ return value[index]
+ elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
+ return value[index]
+ elif isinstance(index, slice):
+ return value[index]
+ elif isinstance(value, (list, tuple)):
+ if not (-len(value) <= index < len(value)):
+ raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
+ return value[int(index)]
+ elif isinstance(value, str):
+ if not (-len(value) <= index < len(value)):
+ raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
+ return value[index]
+ elif index in value:
+ return value[index]
+ elif isinstance(index, str) and isinstance(value, Mapping):
+ close_matches = difflib.get_close_matches(index, list(value.keys()))
+ if len(close_matches) > 0:
+ return value[close_matches[0]]
+ raise InterpreterError(f"Could not index {value} with '{index}'.")
+
+
+def evaluate_name(name, state, static_tools, custom_tools):
+ if name.id in state:
+ return state[name.id]
+ elif name.id in static_tools:
+ return static_tools[name.id]
+ elif name.id in ERRORS:
+ return ERRORS[name.id]
+ close_matches = difflib.get_close_matches(name.id, list(state.keys()))
+ if len(close_matches) > 0:
+ return state[close_matches[0]]
+ raise InterpreterError(f"The variable `{name.id}` is not defined.")
+
+
+def evaluate_condition(condition, state, static_tools, custom_tools):
+ left = evaluate_ast(condition.left, state, static_tools, custom_tools)
+ comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
+ ops = [type(op) for op in condition.ops]
+
+ result = True
+ current_left = left
+
+ for op, comparator in zip(ops, comparators):
+ if op == ast.Eq:
+ current_result = current_left == comparator
+ elif op == ast.NotEq:
+ current_result = current_left != comparator
+ elif op == ast.Lt:
+ current_result = current_left < comparator
+ elif op == ast.LtE:
+ current_result = current_left <= comparator
+ elif op == ast.Gt:
+ current_result = current_left > comparator
+ elif op == ast.GtE:
+ current_result = current_left >= comparator
+ elif op == ast.Is:
+ current_result = current_left is comparator
+ elif op == ast.IsNot:
+ current_result = current_left is not comparator
+ elif op == ast.In:
+ current_result = current_left in comparator
+ elif op == ast.NotIn:
+ current_result = current_left not in comparator
+ else:
+ raise InterpreterError(f"Operator not supported: {op}")
+
+ result = result & current_result
+ current_left = comparator
+
+ if isinstance(result, bool) and not result:
+ break
+
+ return result if isinstance(result, (bool, pd.Series)) else result.all()
+
+
+def evaluate_if(if_statement, state, static_tools, custom_tools):
+ result = None
+ test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
+ if test_result:
+ for line in if_statement.body:
+ line_result = evaluate_ast(line, state, static_tools, custom_tools)
+ if line_result is not None:
+ result = line_result
+ else:
+ for line in if_statement.orelse:
+ line_result = evaluate_ast(line, state, static_tools, custom_tools)
+ if line_result is not None:
+ result = line_result
+ return result
+
+
+def evaluate_for(for_loop, state, static_tools, custom_tools):
+ result = None
+ iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
+ for counter in iterator:
+ set_value(for_loop.target, counter, state, static_tools, custom_tools)
+ for node in for_loop.body:
+ try:
+ line_result = evaluate_ast(node, state, static_tools, custom_tools)
+ if line_result is not None:
+ result = line_result
+ except BreakException:
+ break
+ except ContinueException:
+ continue
+ else:
+ continue
+ break
+ return result
+
+
+def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
+ def inner_evaluate(generators, index, current_state):
+ if index >= len(generators):
+ return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
+ generator = generators[index]
+ iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
+ result = []
+ for value in iter_value:
+ new_state = current_state.copy()
+ if isinstance(generator.target, ast.Tuple):
+ for idx, elem in enumerate(generator.target.elts):
+ new_state[elem.id] = value[idx]
+ else:
+ new_state[generator.target.id] = value
+ if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
+ result.extend(inner_evaluate(generators, index + 1, new_state))
+ return result
+
+ return inner_evaluate(listcomp.generators, 0, state)
+
+
+def evaluate_try(try_node, state, static_tools, custom_tools):
+ try:
+ for stmt in try_node.body:
+ evaluate_ast(stmt, state, static_tools, custom_tools)
+ except Exception as e:
+ matched = False
+ for handler in try_node.handlers:
+ if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
+ matched = True
+ if handler.name:
+ state[handler.name] = e
+ for stmt in handler.body:
+ evaluate_ast(stmt, state, static_tools, custom_tools)
+ break
+ if not matched:
+ raise e
+ else:
+ if try_node.orelse:
+ for stmt in try_node.orelse:
+ evaluate_ast(stmt, state, static_tools, custom_tools)
+ finally:
+ if try_node.finalbody:
+ for stmt in try_node.finalbody:
+ evaluate_ast(stmt, state, static_tools, custom_tools)
+
+
+def evaluate_raise(raise_node, state, static_tools, custom_tools):
+ if raise_node.exc is not None:
+ exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
+ else:
+ exc = None
+ if raise_node.cause is not None:
+ cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
+ else:
+ cause = None
+ if exc is not None:
+ if cause is not None:
+ raise exc from cause
+ else:
+ raise exc
+ else:
+ raise InterpreterError("Re-raise is not supported without an active exception")
+
+
+def evaluate_assert(assert_node, state, static_tools, custom_tools):
+ test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
+ if not test_result:
+ if assert_node.msg:
+ msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
+ raise AssertionError(msg)
+ else:
+ # Include the failing condition in the assertion message
+ test_code = ast.unparse(assert_node.test)
+ raise AssertionError(f"Assertion failed: {test_code}")
+
+
+def evaluate_with(with_node, state, static_tools, custom_tools):
+ contexts = []
+ for item in with_node.items:
+ context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
+ if item.optional_vars:
+ state[item.optional_vars.id] = context_expr.__enter__()
+ contexts.append(state[item.optional_vars.id])
+ else:
+ context_var = context_expr.__enter__()
+ contexts.append(context_var)
+
+ try:
+ for stmt in with_node.body:
+ evaluate_ast(stmt, state, static_tools, custom_tools)
+ except Exception as e:
+ for context in reversed(contexts):
+ context.__exit__(type(e), e, e.__traceback__)
+ raise
+ else:
+ for context in reversed(contexts):
+ context.__exit__(None, None, None)
+
+
+def import_modules(expression, state, authorized_imports):
+ def check_module_authorized(module_name):
+ module_path = module_name.split(".")
+ module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
+ return any(subpath in authorized_imports for subpath in module_subpaths)
+
+ if isinstance(expression, ast.Import):
+ for alias in expression.names:
+ if check_module_authorized(alias.name):
+ module = import_module(alias.name)
+ state[alias.asname or alias.name] = module
+ else:
+ raise InterpreterError(
+ f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
+ )
+ return None
+ elif isinstance(expression, ast.ImportFrom):
+ if check_module_authorized(expression.module):
+ module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
+ for alias in expression.names:
+ state[alias.asname or alias.name] = getattr(module, alias.name)
+ else:
+ raise InterpreterError(f"Import from {expression.module} is not allowed.")
+ return None
+
+
+def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
+ result = {}
+ for gen in dictcomp.generators:
+ iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
+ for value in iter_value:
+ new_state = state.copy()
+ set_value(gen.target, value, new_state, static_tools, custom_tools)
+ if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
+ key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
+ val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
+ result[key] = val
+ return result
+
+
+def evaluate_ast(
+ expression: ast.AST,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str] = LIST_SAFE_MODULES,
+):
+ """
+ Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
+ set of functions.
+
+ This function will recurse trough the nodes of the tree provided.
+
+ Args:
+ expression (`ast.AST`):
+ The code to evaluate, as an abstract syntax tree.
+ state (`Dict[str, Any]`):
+ A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
+ encounters assignements.
+ static_tools (`Dict[str, Callable]`):
+ Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
+ custom_tools (`Dict[str, Callable]`):
+ Functions that may be called during the evaluation. These static_tools can be overwritten.
+ authorized_imports (`List[str]`):
+ The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
+ Add more at your own risk!
+ """
+ global OPERATIONS_COUNT
+ if OPERATIONS_COUNT >= MAX_OPERATIONS:
+ raise InterpreterError(
+ f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
+ )
+ OPERATIONS_COUNT += 1
+ if isinstance(expression, ast.Assign):
+ # Assignement -> we evaluate the assignment which should update the state
+ # We return the variable assigned as it may be used to determine the final result.
+ return evaluate_assign(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.AugAssign):
+ return evaluate_augassign(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Call):
+ # Function call -> we return the value of the function call
+ return evaluate_call(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Constant):
+ # Constant -> just return the value
+ return expression.value
+ elif isinstance(expression, ast.Tuple):
+ return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
+ elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
+ return evaluate_listcomp(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.UnaryOp):
+ return evaluate_unaryop(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Starred):
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.BoolOp):
+ # Boolean operation -> evaluate the operation
+ return evaluate_boolop(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Break):
+ raise BreakException()
+ elif isinstance(expression, ast.Continue):
+ raise ContinueException()
+ elif isinstance(expression, ast.BinOp):
+ # Binary operation -> execute operation
+ return evaluate_binop(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Compare):
+ # Comparison -> evaluate the comparison
+ return evaluate_condition(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Lambda):
+ return evaluate_lambda(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.FunctionDef):
+ return evaluate_function_def(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Dict):
+ # Dict -> evaluate all keys and values
+ keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
+ values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
+ return dict(zip(keys, values))
+ elif isinstance(expression, ast.Expr):
+ # Expression -> evaluate the content
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.For):
+ # For loop -> execute the loop
+ return evaluate_for(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.FormattedValue):
+ # Formatted value (part of f-string) -> evaluate the content and return
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.If):
+ # If -> execute the right branch
+ return evaluate_if(expression, state, static_tools, custom_tools)
+ elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.JoinedStr):
+ return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
+ elif isinstance(expression, ast.List):
+ # List -> evaluate all elements
+ return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
+ elif isinstance(expression, ast.Name):
+ # Name -> pick up the value in the state
+ return evaluate_name(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Subscript):
+ # Subscript -> return the value of the indexing
+ return evaluate_subscript(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.IfExp):
+ test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
+ if test_val:
+ return evaluate_ast(expression.body, state, static_tools, custom_tools)
+ else:
+ return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Attribute):
+ value = evaluate_ast(expression.value, state, static_tools, custom_tools)
+ return getattr(value, expression.attr)
+ elif isinstance(expression, ast.Slice):
+ return slice(
+ evaluate_ast(expression.lower, state, static_tools, custom_tools)
+ if expression.lower is not None
+ else None,
+ evaluate_ast(expression.upper, state, static_tools, custom_tools)
+ if expression.upper is not None
+ else None,
+ evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
+ )
+ elif isinstance(expression, ast.DictComp):
+ return evaluate_dictcomp(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.While):
+ return evaluate_while(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, (ast.Import, ast.ImportFrom)):
+ return import_modules(expression, state, authorized_imports)
+ elif isinstance(expression, ast.ClassDef):
+ return evaluate_class_def(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Try):
+ return evaluate_try(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Raise):
+ return evaluate_raise(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Assert):
+ return evaluate_assert(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.With):
+ return evaluate_with(expression, state, static_tools, custom_tools)
+ elif isinstance(expression, ast.Set):
+ return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
+ elif isinstance(expression, ast.Return):
+ raise ReturnException(
+ evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
+ )
+ else:
+ # For now we refuse anything else. Let's add things as we need them.
+ raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
+
+
+def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
+ if len(print_outputs) < max_len_outputs:
+ return print_outputs
+ else:
+ return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
+
+
+def evaluate_python_code(
+ code: str,
+ static_tools: Optional[Dict[str, Callable]] = None,
+ custom_tools: Optional[Dict[str, Callable]] = None,
+ state: Optional[Dict[str, Any]] = None,
+ authorized_imports: List[str] = LIST_SAFE_MODULES,
+):
+ """
+ Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
+ of functions.
+
+ This function will recurse through the nodes of the tree provided.
+
+ Args:
+ code (`str`):
+ The code to evaluate.
+ static_tools (`Dict[str, Callable]`):
+ The functions that may be called during the evaluation.
+ These tools cannot be overwritten in the code: any assignment to their name will raise an error.
+ custom_tools (`Dict[str, Callable]`):
+ The functions that may be called during the evaluation.
+ These tools can be overwritten in the code: any assignment to their name will overwrite them.
+ state (`Dict[str, Any]`):
+ A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
+ updated by this function to contain all variables as they are evaluated.
+ The print outputs will be stored in the state under the key 'print_outputs'.
+ """
+ try:
+ expression = ast.parse(code)
+ except SyntaxError as e:
+ raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
+ if state is None:
+ state = {}
+ if static_tools is None:
+ static_tools = {}
+ if custom_tools is None:
+ custom_tools = {}
+ result = None
+ global PRINT_OUTPUTS
+ PRINT_OUTPUTS = ""
+ global OPERATIONS_COUNT
+ OPERATIONS_COUNT = 0
+ try:
+ for node in expression.body:
+ result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
+ state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
+ return result
+ except InterpreterError as e:
+ msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
+ msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
+ raise InterpreterError(msg)
diff --git a/agents/search.py b/agents/search.py
new file mode 100644
index 0000000..1c2c339
--- /dev/null
+++ b/agents/search.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+import requests
+from requests.exceptions import RequestException
+
+from .tools import Tool
+
+
+class DuckDuckGoSearchTool(Tool):
+ name = "web_search"
+ description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
+ Each result has keys 'title', 'href' and 'body'."""
+ inputs = {"query": {"type": "string", "description": "The search query to perform."}}
+ output_type = "any"
+
+ def forward(self, query: str) -> str:
+ try:
+ from duckduckgo_search import DDGS
+ except ImportError:
+ raise ImportError(
+ "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
+ )
+ results = DDGS().text(query, max_results=7)
+ return results
+
+
+class VisitWebpageTool(Tool):
+ name = "visit_webpage"
+ description = "Visits a webpage at the given url and returns its content as a markdown string."
+ inputs = {
+ "url": {
+ "type": "string",
+ "description": "The url of the webpage to visit.",
+ }
+ }
+ output_type = "string"
+
+ def forward(self, url: str) -> str:
+ try:
+ from markdownify import markdownify
+ except ImportError:
+ raise ImportError(
+ "You must install package `markdownify` to run this tool: for instance run `pip install markdownify`."
+ )
+ try:
+ # Send a GET request to the URL
+ response = requests.get(url)
+ response.raise_for_status() # Raise an exception for bad status codes
+
+ # Convert the HTML content to Markdown
+ markdown_content = markdownify(response.text).strip()
+
+ # Remove multiple line breaks
+ markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
+
+ return markdown_content
+
+ except RequestException as e:
+ return f"Error fetching the webpage: {str(e)}"
+ except Exception as e:
+ return f"An unexpected error occurred: {str(e)}"
diff --git a/agents/speech_to_text.py b/agents/speech_to_text.py
new file mode 100644
index 0000000..8061651
--- /dev/null
+++ b/agents/speech_to_text.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
+from .tools import PipelineTool
+
+
+class SpeechToTextTool(PipelineTool):
+ default_checkpoint = "distil-whisper/distil-large-v3"
+ description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
+ name = "transcriber"
+ pre_processor_class = WhisperProcessor
+ model_class = WhisperForConditionalGeneration
+
+ inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
+ output_type = "string"
+
+ def encode(self, audio):
+ return self.pre_processor(audio, return_tensors="pt")
+
+ def forward(self, inputs):
+ return self.model.generate(inputs["input_features"])
+
+ def decode(self, outputs):
+ return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
diff --git a/agents/text_to_speech.py b/agents/text_to_speech.py
new file mode 100644
index 0000000..ed41ef6
--- /dev/null
+++ b/agents/text_to_speech.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
+from ..utils import is_datasets_available
+from .tools import PipelineTool
+
+
+if is_datasets_available():
+ from datasets import load_dataset
+
+
+class TextToSpeechTool(PipelineTool):
+ default_checkpoint = "microsoft/speecht5_tts"
+ description = (
+ "This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
+ )
+ name = "text_to_speech"
+ pre_processor_class = SpeechT5Processor
+ model_class = SpeechT5ForTextToSpeech
+ post_processor_class = SpeechT5HifiGan
+
+ inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
+ output_type = "audio"
+
+ def setup(self):
+ if self.post_processor is None:
+ self.post_processor = "microsoft/speecht5_hifigan"
+ super().setup()
+
+ def encode(self, text, speaker_embeddings=None):
+ inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
+
+ if speaker_embeddings is None:
+ if not is_datasets_available():
+ raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
+
+ embeddings_dataset = load_dataset(
+ "Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True
+ )
+ speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
+
+ return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
+
+ def forward(self, inputs):
+ with torch.no_grad():
+ return self.model.generate_speech(**inputs)
+
+ def decode(self, outputs):
+ with torch.no_grad():
+ return self.post_processor(outputs).cpu().detach()
diff --git a/agents/tools.py b/agents/tools.py
new file mode 100644
index 0000000..7597046
--- /dev/null
+++ b/agents/tools.py
@@ -0,0 +1,1003 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import ast
+import base64
+import importlib
+import inspect
+import io
+import json
+import os
+import tempfile
+from functools import lru_cache, wraps
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
+from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
+from packaging import version
+
+from ..dynamic_module_utils import (
+ custom_object_save,
+ get_class_from_dynamic_module,
+ get_imports,
+)
+from ..models.auto import AutoProcessor
+from ..utils import (
+ CONFIG_NAME,
+ TypeHintParsingException,
+ cached_file,
+ get_json_schema,
+ is_accelerate_available,
+ is_torch_available,
+ is_vision_available,
+ logging,
+)
+from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+ import torch
+
+if is_accelerate_available():
+ from accelerate import PartialState
+ from accelerate.utils import send_to_device
+
+
+TOOL_CONFIG_FILE = "tool_config.json"
+
+
+def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
+ if repo_type is not None:
+ return repo_type
+ try:
+ hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
+ return "space"
+ except RepositoryNotFoundError:
+ try:
+ hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
+ return "model"
+ except RepositoryNotFoundError:
+ raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
+ except Exception:
+ return "model"
+ except Exception:
+ return "space"
+
+
+# docstyle-ignore
+APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
+from {module_name} import {class_name}
+
+launch_gradio_demo({class_name})
+"""
+
+
+def validate_after_init(cls, do_validate_forward: bool = True):
+ original_init = cls.__init__
+
+ @wraps(original_init)
+ def new_init(self, *args, **kwargs):
+ original_init(self, *args, **kwargs)
+ if not isinstance(self, PipelineTool):
+ self.validate_arguments(do_validate_forward=do_validate_forward)
+
+ cls.__init__ = new_init
+ return cls
+
+
+CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
+
+
+class Tool:
+ """
+ A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
+ following class attributes:
+
+ - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
+ will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
+ returns the text contained in the file'.
+ - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
+ `"text-classifier"` or `"image_generator"`.
+ - **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
+ It has one `type`key and a `description`key.
+ This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
+ description for your tool.
+ - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
+ or to make a nice space from your tool, and also can be used in the generated description for your tool.
+
+ You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being
+ usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
+ instantiation.
+ """
+
+ name: str
+ description: str
+ inputs: Dict[str, Dict[str, Union[str, type]]]
+ output_type: type
+
+ def __init__(self, *args, **kwargs):
+ self.is_initialized = False
+
+ def __init_subclass__(cls, **kwargs):
+ super().__init_subclass__(**kwargs)
+ validate_after_init(cls, do_validate_forward=False)
+
+ def validate_arguments(self, do_validate_forward: bool = True):
+ required_attributes = {
+ "description": str,
+ "name": str,
+ "inputs": dict,
+ "output_type": str,
+ }
+ authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
+
+ for attr, expected_type in required_attributes.items():
+ attr_value = getattr(self, attr, None)
+ if attr_value is None:
+ raise TypeError(f"You must set an attribute {attr}.")
+ if not isinstance(attr_value, expected_type):
+ raise TypeError(
+ f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
+ )
+ for input_name, input_content in self.inputs.items():
+ assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
+ assert (
+ "type" in input_content and "description" in input_content
+ ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
+ if input_content["type"] not in authorized_types:
+ raise Exception(
+ f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
+ )
+
+ assert getattr(self, "output_type", None) in authorized_types
+ if do_validate_forward:
+ if not isinstance(self, PipelineTool):
+ signature = inspect.signature(self.forward)
+ if not set(signature.parameters.keys()) == set(self.inputs.keys()):
+ raise Exception(
+ "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
+ )
+
+ def forward(self, *args, **kwargs):
+ return NotImplemented("Write this method in your subclass of `Tool`.")
+
+ def __call__(self, *args, **kwargs):
+ args, kwargs = handle_agent_inputs(*args, **kwargs)
+ outputs = self.forward(*args, **kwargs)
+ return handle_agent_outputs(outputs, self.output_type)
+
+ def setup(self):
+ """
+ Overwrite this method here for any operation that is expensive and needs to be executed before you start using
+ your tool. Such as loading a big model.
+ """
+ self.is_initialized = True
+
+ def save(self, output_dir):
+ """
+ Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
+ tool in `output_dir` as well as autogenerate:
+
+ - a config file named `tool_config.json`
+ - an `app.py` file so that your tool can be converted to a space
+ - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
+ code)
+
+ You should only use this method to save tools that are defined in a separate module (not `__main__`).
+
+ Args:
+ output_dir (`str`): The folder in which you want to save your tool.
+ """
+ os.makedirs(output_dir, exist_ok=True)
+ # Save module file
+ if self.__module__ == "__main__":
+ raise ValueError(
+ f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
+ "have to put this code in a separate module so we can include it in the saved folder."
+ )
+ module_files = custom_object_save(self, output_dir)
+
+ module_name = self.__class__.__module__
+ last_module = module_name.split(".")[-1]
+ full_name = f"{last_module}.{self.__class__.__name__}"
+
+ # Save config file
+ config_file = os.path.join(output_dir, "tool_config.json")
+ if os.path.isfile(config_file):
+ with open(config_file, "r", encoding="utf-8") as f:
+ tool_config = json.load(f)
+ else:
+ tool_config = {}
+
+ tool_config = {
+ "tool_class": full_name,
+ "description": self.description,
+ "name": self.name,
+ "inputs": self.inputs,
+ "output_type": str(self.output_type),
+ }
+ with open(config_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
+
+ # Save app file
+ app_file = os.path.join(output_dir, "app.py")
+ with open(app_file, "w", encoding="utf-8") as f:
+ f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
+
+ # Save requirements file
+ requirements_file = os.path.join(output_dir, "requirements.txt")
+ imports = []
+ for module in module_files:
+ imports.extend(get_imports(module))
+ imports = list(set(imports))
+ with open(requirements_file, "w", encoding="utf-8") as f:
+ f.write("\n".join(imports) + "\n")
+
+ @classmethod
+ def from_hub(
+ cls,
+ repo_id: str,
+ token: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Loads a tool defined on the Hub.
+
+
+
+ Loading a tool from the Hub means that you'll download the tool and execute it locally.
+ ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
+ installing a package using pip/npm/apt.
+
+
+
+ Args:
+ repo_id (`str`):
+ The name of the repo on the Hub where your tool is defined.
+ token (`str`, *optional*):
+ The token to identify you on hf.co. If unset, will use the token generated when running
+ `huggingface-cli login` (stored in `~/.huggingface`).
+ kwargs (additional keyword arguments, *optional*):
+ Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
+ `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
+ others will be passed along to its init.
+ """
+ hub_kwargs_names = [
+ "cache_dir",
+ "force_download",
+ "resume_download",
+ "proxies",
+ "revision",
+ "repo_type",
+ "subfolder",
+ "local_files_only",
+ ]
+ hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
+
+ # Try to get the tool config first.
+ hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
+ resolved_config_file = cached_file(
+ repo_id,
+ TOOL_CONFIG_FILE,
+ token=token,
+ **hub_kwargs,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+ is_tool_config = resolved_config_file is not None
+ if resolved_config_file is None:
+ resolved_config_file = cached_file(
+ repo_id,
+ CONFIG_NAME,
+ token=token,
+ **hub_kwargs,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+ if resolved_config_file is None:
+ raise EnvironmentError(
+ f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
+ )
+
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ config = json.load(reader)
+
+ if not is_tool_config:
+ if "custom_tool" not in config:
+ raise EnvironmentError(
+ f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
+ )
+ custom_tool = config["custom_tool"]
+ else:
+ custom_tool = config
+
+ tool_class = custom_tool["tool_class"]
+ tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs)
+
+ if len(tool_class.name) == 0:
+ tool_class.name = custom_tool["name"]
+ if tool_class.name != custom_tool["name"]:
+ logger.warning(
+ f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
+ "configuration name."
+ )
+ tool_class.name = custom_tool["name"]
+
+ if len(tool_class.description) == 0:
+ tool_class.description = custom_tool["description"]
+ if tool_class.description != custom_tool["description"]:
+ logger.warning(
+ f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
+ "tool configuration description."
+ )
+ tool_class.description = custom_tool["description"]
+
+ if tool_class.inputs != custom_tool["inputs"]:
+ tool_class.inputs = custom_tool["inputs"]
+ if tool_class.output_type != custom_tool["output_type"]:
+ tool_class.output_type = custom_tool["output_type"]
+
+ if not isinstance(tool_class.inputs, dict):
+ tool_class.inputs = ast.literal_eval(tool_class.inputs)
+
+ return tool_class(**kwargs)
+
+ def push_to_hub(
+ self,
+ repo_id: str,
+ commit_message: str = "Upload tool",
+ private: Optional[bool] = None,
+ token: Optional[Union[bool, str]] = None,
+ create_pr: bool = False,
+ ) -> str:
+ """
+ Upload the tool to the Hub.
+
+ For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
+ For instance:
+ ```
+ from my_tool_module import MyTool
+ my_tool = MyTool()
+ my_tool.push_to_hub("my-username/my-space")
+ ```
+
+ Parameters:
+ repo_id (`str`):
+ The name of the repository you want to push your tool to. It should contain your organization name when
+ pushing to a given organization.
+ commit_message (`str`, *optional*, defaults to `"Upload tool"`):
+ Message to commit while pushing.
+ private (`bool`, *optional*):
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+ token (`bool` or `str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ create_pr (`bool`, *optional*, defaults to `False`):
+ Whether or not to create a PR with the uploaded files or directly commit.
+ """
+ repo_url = create_repo(
+ repo_id=repo_id,
+ token=token,
+ private=private,
+ exist_ok=True,
+ repo_type="space",
+ space_sdk="gradio",
+ )
+ repo_id = repo_url.repo_id
+ metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ # Save all files.
+ self.save(work_dir)
+ logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
+ return upload_folder(
+ repo_id=repo_id,
+ commit_message=commit_message,
+ folder_path=work_dir,
+ token=token,
+ create_pr=create_pr,
+ repo_type="space",
+ )
+
+ @staticmethod
+ def from_space(
+ space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None
+ ):
+ """
+ Creates a [`Tool`] from a Space given its id on the Hub.
+
+ Args:
+ space_id (`str`):
+ The id of the Space on the Hub.
+ name (`str`):
+ The name of the tool.
+ description (`str`):
+ The description of the tool.
+ api_name (`str`, *optional*):
+ The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api.
+ token (`str`, *optional*):
+ Add your token to access private spaces or increase your GPU quotas.
+ Returns:
+ [`Tool`]:
+ The Space, as a tool.
+
+ Examples:
+ ```
+ image_generator = Tool.from_space(
+ space_id="black-forest-labs/FLUX.1-schnell",
+ name="image-generator",
+ description="Generate an image from a prompt"
+ )
+ image = image_generator("Generate an image of a cool surfer in Tahiti")
+ ```
+ ```
+ face_swapper = Tool.from_space(
+ "tuan2308/face-swap",
+ "face_swapper",
+ "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
+ )
+ image = face_swapper('./aymeric.jpeg', './ruth.jpg')
+ ```
+ """
+ from gradio_client import Client, handle_file
+ from gradio_client.utils import is_http_url_like
+
+ class SpaceToolWrapper(Tool):
+ def __init__(
+ self,
+ space_id: str,
+ name: str,
+ description: str,
+ api_name: Optional[str] = None,
+ token: Optional[str] = None,
+ ):
+ self.client = Client(space_id, hf_token=token)
+ self.name = name
+ self.description = description
+ space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
+
+ # If api_name is not defined, take the first of the available APIs for this space
+ if api_name is None:
+ api_name = list(space_description.keys())[0]
+ logger.warning(
+ f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`."
+ )
+ self.api_name = api_name
+
+ try:
+ space_description_api = space_description[api_name]
+ except KeyError:
+ raise KeyError(f"Could not find specified {api_name=} among available api names.")
+
+ self.inputs = {}
+ for parameter in space_description_api["parameters"]:
+ if not parameter["parameter_has_default"]:
+ parameter_type = parameter["type"]["type"]
+ if parameter_type == "object":
+ parameter_type = "any"
+ self.inputs[parameter["parameter_name"]] = {
+ "type": parameter_type,
+ "description": parameter["python_type"]["description"],
+ }
+ output_component = space_description_api["returns"][0]["component"]
+ if output_component == "Image":
+ self.output_type = "image"
+ elif output_component == "Audio":
+ self.output_type = "audio"
+ else:
+ self.output_type = "any"
+
+ def sanitize_argument_for_prediction(self, arg):
+ if isinstance(arg, ImageType):
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+ arg.save(temp_file.name)
+ arg = temp_file.name
+ if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like(
+ arg
+ ):
+ arg = handle_file(arg)
+ return arg
+
+ def forward(self, *args, **kwargs):
+ # Preprocess args and kwargs:
+ args = list(args)
+ for i, arg in enumerate(args):
+ args[i] = self.sanitize_argument_for_prediction(arg)
+ for arg_name, arg in kwargs.items():
+ kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)
+
+ output = self.client.predict(*args, api_name=self.api_name, **kwargs)
+ if isinstance(output, tuple) or isinstance(output, list):
+ return output[
+ 0
+ ] # Sometime the space also returns the generation seed, in which case the result is at index 0
+ return output
+
+ return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token)
+
+ @staticmethod
+ def from_gradio(gradio_tool):
+ """
+ Creates a [`Tool`] from a gradio tool.
+ """
+ import inspect
+
+ class GradioToolWrapper(Tool):
+ def __init__(self, _gradio_tool):
+ self.name = _gradio_tool.name
+ self.description = _gradio_tool.description
+ self.output_type = "string"
+ self._gradio_tool = _gradio_tool
+ func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
+ self.inputs = {
+ key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
+ }
+ self.forward = self._gradio_tool.run
+
+ return GradioToolWrapper(gradio_tool)
+
+ @staticmethod
+ def from_langchain(langchain_tool):
+ """
+ Creates a [`Tool`] from a langchain tool.
+ """
+
+ class LangChainToolWrapper(Tool):
+ def __init__(self, _langchain_tool):
+ self.name = _langchain_tool.name.lower()
+ self.description = _langchain_tool.description
+ self.inputs = _langchain_tool.args.copy()
+ for input_content in self.inputs.values():
+ if "title" in input_content:
+ input_content.pop("title")
+ input_content["description"] = ""
+ self.output_type = "string"
+ self.langchain_tool = _langchain_tool
+
+ def forward(self, *args, **kwargs):
+ tool_input = kwargs.copy()
+ for index, argument in enumerate(args):
+ if index < len(self.inputs):
+ input_key = next(iter(self.inputs))
+ tool_input[input_key] = argument
+ return self.langchain_tool.run(tool_input)
+
+ return LangChainToolWrapper(langchain_tool)
+
+
+DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
+- {{ tool.name }}: {{ tool.description }}
+ Takes inputs: {{tool.inputs}}
+ Returns an output of type: {{tool.output_type}}
+"""
+
+
+def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
+ compiled_template = compile_jinja_template(description_template)
+ rendered = compiled_template.render(
+ tool=tool,
+ )
+ return rendered
+
+
+@lru_cache
+def compile_jinja_template(template):
+ try:
+ import jinja2
+ from jinja2.exceptions import TemplateError
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
+ except ImportError:
+ raise ImportError("template requires jinja2 to be installed.")
+
+ if version.parse(jinja2.__version__) < version.parse("3.1.0"):
+ raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
+
+ def raise_exception(message):
+ raise TemplateError(message)
+
+ jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+ jinja_env.globals["raise_exception"] = raise_exception
+ return jinja_env.from_string(template)
+
+
+class PipelineTool(Tool):
+ """
+ A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
+ need to specify:
+
+ - **model_class** (`type`) -- The class to use to load the model in this tool.
+ - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
+ - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
+ pre-processor
+ - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
+ post-processor (when different from the pre-processor).
+
+ Args:
+ model (`str` or [`PreTrainedModel`], *optional*):
+ The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
+ value of the class attribute `default_checkpoint`.
+ pre_processor (`str` or `Any`, *optional*):
+ The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
+ tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
+ unset.
+ post_processor (`str` or `Any`, *optional*):
+ The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
+ tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
+ unset.
+ device (`int`, `str` or `torch.device`, *optional*):
+ The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
+ CPU otherwise.
+ device_map (`str` or `dict`, *optional*):
+ If passed along, will be used to instantiate the model.
+ model_kwargs (`dict`, *optional*):
+ Any keyword argument to send to the model instantiation.
+ token (`str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
+ running `huggingface-cli login` (stored in `~/.huggingface`).
+ hub_kwargs (additional keyword arguments, *optional*):
+ Any additional keyword argument to send to the methods that will load the data from the Hub.
+ """
+
+ pre_processor_class = AutoProcessor
+ model_class = None
+ post_processor_class = AutoProcessor
+ default_checkpoint = None
+ description = "This is a pipeline tool"
+ name = "pipeline"
+ inputs = {"prompt": str}
+ output_type = str
+
+ def __init__(
+ self,
+ model=None,
+ pre_processor=None,
+ post_processor=None,
+ device=None,
+ device_map=None,
+ model_kwargs=None,
+ token=None,
+ **hub_kwargs,
+ ):
+ if not is_torch_available():
+ raise ImportError("Please install torch in order to use this tool.")
+
+ if not is_accelerate_available():
+ raise ImportError("Please install accelerate in order to use this tool.")
+
+ if model is None:
+ if self.default_checkpoint is None:
+ raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
+ model = self.default_checkpoint
+ if pre_processor is None:
+ pre_processor = model
+
+ self.model = model
+ self.pre_processor = pre_processor
+ self.post_processor = post_processor
+ self.device = device
+ self.device_map = device_map
+ self.model_kwargs = {} if model_kwargs is None else model_kwargs
+ if device_map is not None:
+ self.model_kwargs["device_map"] = device_map
+ self.hub_kwargs = hub_kwargs
+ self.hub_kwargs["token"] = token
+
+ super().__init__()
+
+ def setup(self):
+ """
+ Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
+ """
+ if isinstance(self.pre_processor, str):
+ self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
+
+ if isinstance(self.model, str):
+ self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
+
+ if self.post_processor is None:
+ self.post_processor = self.pre_processor
+ elif isinstance(self.post_processor, str):
+ self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
+
+ if self.device is None:
+ if self.device_map is not None:
+ self.device = list(self.model.hf_device_map.values())[0]
+ else:
+ self.device = PartialState().default_device
+
+ if self.device_map is None:
+ self.model.to(self.device)
+
+ super().setup()
+
+ def encode(self, raw_inputs):
+ """
+ Uses the `pre_processor` to prepare the inputs for the `model`.
+ """
+ return self.pre_processor(raw_inputs)
+
+ def forward(self, inputs):
+ """
+ Sends the inputs through the `model`.
+ """
+ with torch.no_grad():
+ return self.model(**inputs)
+
+ def decode(self, outputs):
+ """
+ Uses the `post_processor` to decode the model output.
+ """
+ return self.post_processor(outputs)
+
+ def __call__(self, *args, **kwargs):
+ args, kwargs = handle_agent_inputs(*args, **kwargs)
+
+ if not self.is_initialized:
+ self.setup()
+
+ encoded_inputs = self.encode(*args, **kwargs)
+
+ tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
+ non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
+
+ encoded_inputs = send_to_device(tensor_inputs, self.device)
+ outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
+ outputs = send_to_device(outputs, "cpu")
+ decoded_outputs = self.decode(outputs)
+
+ return handle_agent_outputs(decoded_outputs, self.output_type)
+
+
+def launch_gradio_demo(tool_class: Tool):
+ """
+ Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
+ `inputs` and `output_type`.
+
+ Args:
+ tool_class (`type`): The class of the tool for which to launch the demo.
+ """
+ try:
+ import gradio as gr
+ except ImportError:
+ raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+
+ tool = tool_class()
+
+ def fn(*args, **kwargs):
+ return tool(*args, **kwargs)
+
+ TYPE_TO_COMPONENT_CLASS_MAPPING = {
+ "image": gr.Image,
+ "audio": gr.Audio,
+ "string": gr.Textbox,
+ "integer": gr.Textbox,
+ "number": gr.Textbox,
+ }
+
+ gradio_inputs = []
+ for input_name, input_details in tool_class.inputs.items():
+ input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
+ new_component = input_gradio_component_class(label=input_name)
+ gradio_inputs.append(new_component)
+
+ output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
+ gradio_output = output_gradio_componentclass(label=input_name)
+
+ gr.Interface(
+ fn=fn,
+ inputs=gradio_inputs,
+ outputs=gradio_output,
+ title=tool_class.__name__,
+ article=tool.description,
+ ).launch()
+
+
+TOOL_MAPPING = {
+ "document_question_answering": "DocumentQuestionAnsweringTool",
+ "image_question_answering": "ImageQuestionAnsweringTool",
+ "speech_to_text": "SpeechToTextTool",
+ "text_to_speech": "TextToSpeechTool",
+ "translation": "TranslationTool",
+ "python_interpreter": "PythonInterpreterTool",
+ "web_search": "DuckDuckGoSearchTool",
+}
+
+
+def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
+ """
+ Main function to quickly load a tool, be it on the Hub or in the Transformers library.
+
+
+
+ Loading a tool means that you'll download the tool and execute it locally.
+ ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
+ installing a package using pip/npm/apt.
+
+
+
+ Args:
+ task_or_repo_id (`str`):
+ The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
+ are:
+
+ - `"document_question_answering"`
+ - `"image_question_answering"`
+ - `"speech_to_text"`
+ - `"text_to_speech"`
+ - `"translation"`
+
+ model_repo_id (`str`, *optional*):
+ Use this argument to use a different model than the default one for the tool you selected.
+ token (`str`, *optional*):
+ The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
+ login` (stored in `~/.huggingface`).
+ kwargs (additional keyword arguments, *optional*):
+ Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
+ `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
+ will be passed along to its init.
+ """
+ if task_or_repo_id in TOOL_MAPPING:
+ tool_class_name = TOOL_MAPPING[task_or_repo_id]
+ main_module = importlib.import_module("transformers")
+ tools_module = main_module.agents
+ tool_class = getattr(tools_module, tool_class_name)
+ return tool_class(model_repo_id, token=token, **kwargs)
+ else:
+ logger.warning_once(
+ f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
+ f"trust as the code within that tool will be executed on your machine. Always verify the code of "
+ f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
+ f"code that you have checked."
+ )
+ return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
+
+
+def add_description(description):
+ """
+ A decorator that adds a description to a function.
+ """
+
+ def inner(func):
+ func.description = description
+ func.name = func.__name__
+ return func
+
+ return inner
+
+
+## Will move to the Hub
+class EndpointClient:
+ def __init__(self, endpoint_url: str, token: Optional[str] = None):
+ self.headers = {
+ **build_hf_headers(token=token),
+ "Content-Type": "application/json",
+ }
+ self.endpoint_url = endpoint_url
+
+ @staticmethod
+ def encode_image(image):
+ _bytes = io.BytesIO()
+ image.save(_bytes, format="PNG")
+ b64 = base64.b64encode(_bytes.getvalue())
+ return b64.decode("utf-8")
+
+ @staticmethod
+ def decode_image(raw_image):
+ if not is_vision_available():
+ raise ImportError(
+ "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
+ )
+
+ from PIL import Image
+
+ b64 = base64.b64decode(raw_image)
+ _bytes = io.BytesIO(b64)
+ return Image.open(_bytes)
+
+ def __call__(
+ self,
+ inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
+ params: Optional[Dict] = None,
+ data: Optional[bytes] = None,
+ output_image: bool = False,
+ ) -> Any:
+ # Build payload
+ payload = {}
+ if inputs:
+ payload["inputs"] = inputs
+ if params:
+ payload["parameters"] = params
+
+ # Make API call
+ response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
+
+ # By default, parse the response for the user.
+ if output_image:
+ return self.decode_image(response.content)
+ else:
+ return response.json()
+
+
+class ToolCollection:
+ """
+ Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
+
+ > [!NOTE]
+ > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
+ > like for this collection to showcase them.
+
+ Args:
+ collection_slug (str):
+ The collection slug referencing the collection.
+ token (str, *optional*):
+ The authentication token if the collection is private.
+
+ Example:
+
+ ```py
+ >>> from transformers import ToolCollection, ReactCodeAgent
+
+ >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
+ >>> agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
+
+ >>> agent.run("Please draw me a picture of rivers and lakes.")
+ ```
+ """
+
+ def __init__(self, collection_slug: str, token: Optional[str] = None):
+ self._collection = get_collection(collection_slug, token=token)
+ self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
+ self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
+
+
+def tool(tool_function: Callable) -> Tool:
+ """
+ Converts a function into an instance of a Tool subclass.
+
+ Args:
+ tool_function: Your function. Should have type hints for each input and a type hint for the output.
+ Should also have a docstring description including an 'Args:' part where each argument is described.
+ """
+ parameters = get_json_schema(tool_function)["function"]
+ if "return" not in parameters:
+ raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
+ class_name = f"{parameters['name'].capitalize()}Tool"
+
+ class SpecificTool(Tool):
+ name = parameters["name"]
+ description = parameters["description"]
+ inputs = parameters["parameters"]["properties"]
+ output_type = parameters["return"]["type"]
+
+ @wraps(tool_function)
+ def forward(self, *args, **kwargs):
+ return tool_function(*args, **kwargs)
+
+ original_signature = inspect.signature(tool_function)
+ new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
+ original_signature.parameters.values()
+ )
+ new_signature = original_signature.replace(parameters=new_parameters)
+ SpecificTool.forward.__signature__ = new_signature
+
+ SpecificTool.__name__ = class_name
+ return SpecificTool()
diff --git a/agents/translation.py b/agents/translation.py
new file mode 100644
index 0000000..7ae61f9
--- /dev/null
+++ b/agents/translation.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
+from .tools import PipelineTool
+
+
+LANGUAGE_CODES = {
+ "Acehnese Arabic": "ace_Arab",
+ "Acehnese Latin": "ace_Latn",
+ "Mesopotamian Arabic": "acm_Arab",
+ "Ta'izzi-Adeni Arabic": "acq_Arab",
+ "Tunisian Arabic": "aeb_Arab",
+ "Afrikaans": "afr_Latn",
+ "South Levantine Arabic": "ajp_Arab",
+ "Akan": "aka_Latn",
+ "Amharic": "amh_Ethi",
+ "North Levantine Arabic": "apc_Arab",
+ "Modern Standard Arabic": "arb_Arab",
+ "Modern Standard Arabic Romanized": "arb_Latn",
+ "Najdi Arabic": "ars_Arab",
+ "Moroccan Arabic": "ary_Arab",
+ "Egyptian Arabic": "arz_Arab",
+ "Assamese": "asm_Beng",
+ "Asturian": "ast_Latn",
+ "Awadhi": "awa_Deva",
+ "Central Aymara": "ayr_Latn",
+ "South Azerbaijani": "azb_Arab",
+ "North Azerbaijani": "azj_Latn",
+ "Bashkir": "bak_Cyrl",
+ "Bambara": "bam_Latn",
+ "Balinese": "ban_Latn",
+ "Belarusian": "bel_Cyrl",
+ "Bemba": "bem_Latn",
+ "Bengali": "ben_Beng",
+ "Bhojpuri": "bho_Deva",
+ "Banjar Arabic": "bjn_Arab",
+ "Banjar Latin": "bjn_Latn",
+ "Standard Tibetan": "bod_Tibt",
+ "Bosnian": "bos_Latn",
+ "Buginese": "bug_Latn",
+ "Bulgarian": "bul_Cyrl",
+ "Catalan": "cat_Latn",
+ "Cebuano": "ceb_Latn",
+ "Czech": "ces_Latn",
+ "Chokwe": "cjk_Latn",
+ "Central Kurdish": "ckb_Arab",
+ "Crimean Tatar": "crh_Latn",
+ "Welsh": "cym_Latn",
+ "Danish": "dan_Latn",
+ "German": "deu_Latn",
+ "Southwestern Dinka": "dik_Latn",
+ "Dyula": "dyu_Latn",
+ "Dzongkha": "dzo_Tibt",
+ "Greek": "ell_Grek",
+ "English": "eng_Latn",
+ "Esperanto": "epo_Latn",
+ "Estonian": "est_Latn",
+ "Basque": "eus_Latn",
+ "Ewe": "ewe_Latn",
+ "Faroese": "fao_Latn",
+ "Fijian": "fij_Latn",
+ "Finnish": "fin_Latn",
+ "Fon": "fon_Latn",
+ "French": "fra_Latn",
+ "Friulian": "fur_Latn",
+ "Nigerian Fulfulde": "fuv_Latn",
+ "Scottish Gaelic": "gla_Latn",
+ "Irish": "gle_Latn",
+ "Galician": "glg_Latn",
+ "Guarani": "grn_Latn",
+ "Gujarati": "guj_Gujr",
+ "Haitian Creole": "hat_Latn",
+ "Hausa": "hau_Latn",
+ "Hebrew": "heb_Hebr",
+ "Hindi": "hin_Deva",
+ "Chhattisgarhi": "hne_Deva",
+ "Croatian": "hrv_Latn",
+ "Hungarian": "hun_Latn",
+ "Armenian": "hye_Armn",
+ "Igbo": "ibo_Latn",
+ "Ilocano": "ilo_Latn",
+ "Indonesian": "ind_Latn",
+ "Icelandic": "isl_Latn",
+ "Italian": "ita_Latn",
+ "Javanese": "jav_Latn",
+ "Japanese": "jpn_Jpan",
+ "Kabyle": "kab_Latn",
+ "Jingpho": "kac_Latn",
+ "Kamba": "kam_Latn",
+ "Kannada": "kan_Knda",
+ "Kashmiri Arabic": "kas_Arab",
+ "Kashmiri Devanagari": "kas_Deva",
+ "Georgian": "kat_Geor",
+ "Central Kanuri Arabic": "knc_Arab",
+ "Central Kanuri Latin": "knc_Latn",
+ "Kazakh": "kaz_Cyrl",
+ "Kabiyรจ": "kbp_Latn",
+ "Kabuverdianu": "kea_Latn",
+ "Khmer": "khm_Khmr",
+ "Kikuyu": "kik_Latn",
+ "Kinyarwanda": "kin_Latn",
+ "Kyrgyz": "kir_Cyrl",
+ "Kimbundu": "kmb_Latn",
+ "Northern Kurdish": "kmr_Latn",
+ "Kikongo": "kon_Latn",
+ "Korean": "kor_Hang",
+ "Lao": "lao_Laoo",
+ "Ligurian": "lij_Latn",
+ "Limburgish": "lim_Latn",
+ "Lingala": "lin_Latn",
+ "Lithuanian": "lit_Latn",
+ "Lombard": "lmo_Latn",
+ "Latgalian": "ltg_Latn",
+ "Luxembourgish": "ltz_Latn",
+ "Luba-Kasai": "lua_Latn",
+ "Ganda": "lug_Latn",
+ "Luo": "luo_Latn",
+ "Mizo": "lus_Latn",
+ "Standard Latvian": "lvs_Latn",
+ "Magahi": "mag_Deva",
+ "Maithili": "mai_Deva",
+ "Malayalam": "mal_Mlym",
+ "Marathi": "mar_Deva",
+ "Minangkabau Arabic ": "min_Arab",
+ "Minangkabau Latin": "min_Latn",
+ "Macedonian": "mkd_Cyrl",
+ "Plateau Malagasy": "plt_Latn",
+ "Maltese": "mlt_Latn",
+ "Meitei Bengali": "mni_Beng",
+ "Halh Mongolian": "khk_Cyrl",
+ "Mossi": "mos_Latn",
+ "Maori": "mri_Latn",
+ "Burmese": "mya_Mymr",
+ "Dutch": "nld_Latn",
+ "Norwegian Nynorsk": "nno_Latn",
+ "Norwegian Bokmรฅl": "nob_Latn",
+ "Nepali": "npi_Deva",
+ "Northern Sotho": "nso_Latn",
+ "Nuer": "nus_Latn",
+ "Nyanja": "nya_Latn",
+ "Occitan": "oci_Latn",
+ "West Central Oromo": "gaz_Latn",
+ "Odia": "ory_Orya",
+ "Pangasinan": "pag_Latn",
+ "Eastern Panjabi": "pan_Guru",
+ "Papiamento": "pap_Latn",
+ "Western Persian": "pes_Arab",
+ "Polish": "pol_Latn",
+ "Portuguese": "por_Latn",
+ "Dari": "prs_Arab",
+ "Southern Pashto": "pbt_Arab",
+ "Ayacucho Quechua": "quy_Latn",
+ "Romanian": "ron_Latn",
+ "Rundi": "run_Latn",
+ "Russian": "rus_Cyrl",
+ "Sango": "sag_Latn",
+ "Sanskrit": "san_Deva",
+ "Santali": "sat_Olck",
+ "Sicilian": "scn_Latn",
+ "Shan": "shn_Mymr",
+ "Sinhala": "sin_Sinh",
+ "Slovak": "slk_Latn",
+ "Slovenian": "slv_Latn",
+ "Samoan": "smo_Latn",
+ "Shona": "sna_Latn",
+ "Sindhi": "snd_Arab",
+ "Somali": "som_Latn",
+ "Southern Sotho": "sot_Latn",
+ "Spanish": "spa_Latn",
+ "Tosk Albanian": "als_Latn",
+ "Sardinian": "srd_Latn",
+ "Serbian": "srp_Cyrl",
+ "Swati": "ssw_Latn",
+ "Sundanese": "sun_Latn",
+ "Swedish": "swe_Latn",
+ "Swahili": "swh_Latn",
+ "Silesian": "szl_Latn",
+ "Tamil": "tam_Taml",
+ "Tatar": "tat_Cyrl",
+ "Telugu": "tel_Telu",
+ "Tajik": "tgk_Cyrl",
+ "Tagalog": "tgl_Latn",
+ "Thai": "tha_Thai",
+ "Tigrinya": "tir_Ethi",
+ "Tamasheq Latin": "taq_Latn",
+ "Tamasheq Tifinagh": "taq_Tfng",
+ "Tok Pisin": "tpi_Latn",
+ "Tswana": "tsn_Latn",
+ "Tsonga": "tso_Latn",
+ "Turkmen": "tuk_Latn",
+ "Tumbuka": "tum_Latn",
+ "Turkish": "tur_Latn",
+ "Twi": "twi_Latn",
+ "Central Atlas Tamazight": "tzm_Tfng",
+ "Uyghur": "uig_Arab",
+ "Ukrainian": "ukr_Cyrl",
+ "Umbundu": "umb_Latn",
+ "Urdu": "urd_Arab",
+ "Northern Uzbek": "uzn_Latn",
+ "Venetian": "vec_Latn",
+ "Vietnamese": "vie_Latn",
+ "Waray": "war_Latn",
+ "Wolof": "wol_Latn",
+ "Xhosa": "xho_Latn",
+ "Eastern Yiddish": "ydd_Hebr",
+ "Yoruba": "yor_Latn",
+ "Yue Chinese": "yue_Hant",
+ "Chinese Simplified": "zho_Hans",
+ "Chinese Traditional": "zho_Hant",
+ "Standard Malay": "zsm_Latn",
+ "Zulu": "zul_Latn",
+}
+
+
+class TranslationTool(PipelineTool):
+ """
+ Example:
+
+ ```py
+ from transformers.agents import TranslationTool
+
+ translator = TranslationTool()
+ translator("This is a super nice API!", src_lang="English", tgt_lang="French")
+ ```
+ """
+
+ lang_to_code = LANGUAGE_CODES
+ default_checkpoint = "facebook/nllb-200-distilled-600M"
+ description = (
+ "This is a tool that translates text from a language to another."
+ f"Both `src_lang`and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
+ )
+ name = "translator"
+ pre_processor_class = AutoTokenizer
+ model_class = AutoModelForSeq2SeqLM
+
+ inputs = {
+ "text": {"type": "string", "description": "The text to translate"},
+ "src_lang": {
+ "type": "string",
+ "description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
+ },
+ "tgt_lang": {
+ "type": "string",
+ "description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
+ },
+ }
+ output_type = "string"
+
+ def encode(self, text, src_lang, tgt_lang):
+ if src_lang not in self.lang_to_code:
+ raise ValueError(f"{src_lang} is not a supported language.")
+ if tgt_lang not in self.lang_to_code:
+ raise ValueError(f"{tgt_lang} is not a supported language.")
+ src_lang = self.lang_to_code[src_lang]
+ tgt_lang = self.lang_to_code[tgt_lang]
+ return self.pre_processor._build_translation_inputs(
+ text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
+ )
+
+ def forward(self, inputs):
+ return self.model.generate(**inputs)
+
+ def decode(self, outputs):
+ return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..8879933
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,19 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS =
+SPHINXBUILD = sphinx-build
+SOURCEDIR = source
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..4c08929
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,267 @@
+
+
+# Generating the documentation
+
+To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
+you can install them with the following command, at the root of the code repository:
+
+```bash
+pip install -e ".[docs]"
+```
+
+Then you need to install our special tool that builds the documentation:
+
+```bash
+pip install git+https://github.com/huggingface/doc-builder
+```
+
+---
+**NOTE**
+
+You only need to generate the documentation to inspect it locally (if you're planning changes and want to
+check how they look before committing for instance). You don't have to commit the built documentation.
+
+---
+
+## Building the documentation
+
+Once you have setup the `doc-builder` and additional packages, you can generate the documentation by
+typing the following command:
+
+```bash
+doc-builder build accelerate docs/source/ --build_dir ~/tmp/test-build
+```
+
+You can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate
+the MDX files that will be rendered as the documentation on the main website. You can inspect them in your favorite
+Markdown editor.
+
+## Previewing the documentation
+
+To preview the docs, first install the `watchdog` module with:
+
+```bash
+pip install watchdog
+```
+
+Then run the following command:
+
+```bash
+doc-builder preview {package_name} {path_to_docs}
+```
+
+For example:
+
+```bash
+doc-builder preview accelerate docs/source/
+```
+
+The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
+
+---
+**NOTE**
+
+The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
+
+---
+
+## Adding a new element to the navigation bar
+
+Accepted files are Markdown (.md).
+
+Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
+the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/accelerate/blob/main/docs/source/_toctree.yml) file.
+
+## Renaming section headers and moving sections
+
+It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information.
+
+Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor.
+
+So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file:
+
+```
+Sections that were moved:
+
+[ Section A ]
+```
+and of course, if you moved it to another file, then:
+
+```
+Sections that were moved:
+
+[ Section A ]
+```
+
+Use the relative style to link to the new file so that the versioned docs continue to work.
+
+
+## Writing Documentation - Specification
+
+The `huggingface/accelerate` documentation follows the
+[Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
+although we can write them directly in Markdown.
+
+### Adding a new tutorial
+
+Adding a new tutorial or section is done in two steps:
+
+- Add a new file under `./source`. This file can either be ReStructuredText (.rst) or Markdown (.md).
+- Link that file in `./source/_toctree.yml` on the correct toc-tree.
+
+Make sure to put your new file under the proper section. It's unlikely to go in the first section (*Get Started*), so
+depending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or
+four.
+
+### Writing source documentation
+
+Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names
+and objects like True, None, or any strings should usually be put in `code`.
+
+When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool
+adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
+function to be in the main package.
+
+If you want to create a link to some internal class or function, you need to
+provide its path. For instance: \[\`utils.gather\`\]. This will be converted into a link with
+`utils.gather` in the description. To get rid of the path and only keep the name of the object you are
+linking to in the description, add a ~: \[\`~utils.gather\`\] will generate a link with `gather` in the description.
+
+The same works for methods so you can either use \[\`XXXClass.method\`\] or \[~\`XXXClass.method\`\].
+
+#### Defining arguments in a method
+
+Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and
+an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
+description:
+
+```
+ Args:
+ n_layers (`int`): The number of layers of the model.
+```
+
+If the description is too long to fit in one line (more than 119 characters in total), another indentation is necessary
+before writing the description after the argument.
+
+Finally, to maintain uniformity if any *one* description is too long to fit on one line, the
+rest of the parameters should follow suit and have an indention before their description.
+
+Here's an example showcasing everything so far:
+
+```
+ Args:
+ gradient_accumulation_steps (`int`, *optional*, default to 1):
+ The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with `Accelerator.accumulate`.
+ cpu (`bool`, *optional*):
+ Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force the execution on one process only.
+```
+
+For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
+following signature:
+
+```
+def my_function(x: str = None, a: float = 1):
+```
+
+then its documentation should look like this:
+
+```
+ Args:
+ x (`str`, *optional*):
+ This argument controls ... and has a description longer than 119 chars.
+ a (`float`, *optional*, defaults to 1):
+ This argument is used to ... and has a description longer than 119 chars.
+```
+
+Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even
+if the first line describing your argument type and its default gets long, you can't break it on several lines. You can
+however write as many lines as you want in the indented description (see the example above with `input_ids`).
+
+#### Writing a multi-line code block
+
+Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
+
+
+````
+```python
+# first line of code
+# second line
+# etc
+```
+````
+
+#### Writing a return block
+
+The return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.
+The first line should be the type of the return, followed by a line return. No need to indent further for the elements
+building the return.
+
+Here's an example of a single value return:
+
+```
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
+```
+
+Here's an example of a tuple return, comprising several objects:
+
+```
+ Returns:
+ `tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
+ - ** loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` --
+ Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
+ - **prediction_scores** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+```
+
+## Styling the docstring
+
+We have an automatic script running with the `make style` comment that will make sure that:
+- the docstrings fully take advantage of the line width
+- all code examples are formatted using black, like the code of the Transformers library
+
+This script may have some weird failures if you made a syntax mistake or if you uncover a bug. Therefore, it's
+recommended to commit your changes before running `make style`, so you can revert the changes done by that script
+easily.
+
+## Writing documentation examples
+
+The syntax for Example docstrings can look as follows:
+
+```
+ Example:
+
+ ```python
+ >>> import time
+ >>> from accelerate import Accelerator
+ >>> accelerator = Accelerator()
+ >>> if accelerator.is_main_process:
+ ... time.sleep(2)
+ >>> else:
+ ... print("I'm waiting for the main process to finish its sleep...")
+ >>> accelerator.wait_for_everyone()
+ >>> # Should print on every process at the same time
+ >>> print("Everyone is here")
+ ```
+```
+
+The docstring should give a minimal, clear example of how the respective function
+is to be used in inference and also include the expected (ideally sensible)
+output.
+Often, readers will try out the example before even going through the function
+or class definitions. Therefore, it is of utmost importance that the example
+works as expected.
\ No newline at end of file
diff --git a/docs/source/_config.py b/docs/source/_config.py
new file mode 100644
index 0000000..f49e4e4
--- /dev/null
+++ b/docs/source/_config.py
@@ -0,0 +1,14 @@
+# docstyle-ignore
+INSTALL_CONTENT = """
+# Transformers installation
+! pip install transformers datasets evaluate accelerate
+# To install from source instead of the last release, comment the command above and uncomment the following one.
+# ! pip install git+https://github.com/huggingface/transformers.git
+"""
+
+notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
+black_avoid_patterns = {
+ "{processor_class}": "FakeProcessorClass",
+ "{model_class}": "FakeModelClass",
+ "{object_class}": "FakeObjectClass",
+}
diff --git a/docs/source/_redirects.yml b/docs/source/_redirects.yml
new file mode 100644
index 0000000..ff70547
--- /dev/null
+++ b/docs/source/_redirects.yml
@@ -0,0 +1,5 @@
+# Optimizing inference
+
+perf_infer_gpu_many: perf_infer_gpu_one
+transformers_agents: agents
+quantization: quantization/overview
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
new file mode 100644
index 0000000..6e325e4
--- /dev/null
+++ b/docs/source/_toctree.yml
@@ -0,0 +1,984 @@
+- sections:
+ - local: index
+ title: ๐ค Transformers
+ - local: quicktour
+ title: Quick tour
+ - local: installation
+ title: Installation
+ - local: add_new_model
+ title: Adding a new model to `transformers`
+ title: Get started
+- sections:
+ - local: pipeline_tutorial
+ title: Run inference with pipelines
+ - local: autoclass_tutorial
+ title: Write portable code with AutoClass
+ - local: preprocessing
+ title: Preprocess data
+ - local: training
+ title: Fine-tune a pretrained model
+ - local: run_scripts
+ title: Train with a script
+ - local: accelerate
+ title: Set up distributed training with ๐ค Accelerate
+ - local: peft
+ title: Load and train adapters with ๐ค PEFT
+ - local: model_sharing
+ title: Share your model
+ - local: agents
+ title: Agents 101
+ - local: agents_advanced
+ title: Agents, supercharged - Multi-agents, External tools, and more
+ - local: llm_tutorial
+ title: Generation with LLMs
+ - local: conversations
+ title: Chatting with Transformers
+ title: Tutorials
+- sections:
+ - isExpanded: false
+ sections:
+ - local: tasks/sequence_classification
+ title: Text classification
+ - local: tasks/token_classification
+ title: Token classification
+ - local: tasks/question_answering
+ title: Question answering
+ - local: tasks/language_modeling
+ title: Causal language modeling
+ - local: tasks/masked_language_modeling
+ title: Masked language modeling
+ - local: tasks/translation
+ title: Translation
+ - local: tasks/summarization
+ title: Summarization
+ - local: tasks/multiple_choice
+ title: Multiple choice
+ title: Natural Language Processing
+ - isExpanded: false
+ sections:
+ - local: tasks/audio_classification
+ title: Audio classification
+ - local: tasks/asr
+ title: Automatic speech recognition
+ title: Audio
+ - isExpanded: false
+ sections:
+ - local: tasks/image_classification
+ title: Image classification
+ - local: tasks/semantic_segmentation
+ title: Image segmentation
+ - local: tasks/video_classification
+ title: Video classification
+ - local: tasks/object_detection
+ title: Object detection
+ - local: tasks/zero_shot_object_detection
+ title: Zero-shot object detection
+ - local: tasks/zero_shot_image_classification
+ title: Zero-shot image classification
+ - local: tasks/monocular_depth_estimation
+ title: Depth estimation
+ - local: tasks/image_to_image
+ title: Image-to-Image
+ - local: tasks/image_feature_extraction
+ title: Image Feature Extraction
+ - local: tasks/mask_generation
+ title: Mask Generation
+ - local: tasks/keypoint_detection
+ title: Keypoint Detection
+ - local: tasks/knowledge_distillation_for_image_classification
+ title: Knowledge Distillation for Computer Vision
+ title: Computer Vision
+ - isExpanded: false
+ sections:
+ - local: tasks/image_captioning
+ title: Image captioning
+ - local: tasks/document_question_answering
+ title: Document Question Answering
+ - local: tasks/visual_question_answering
+ title: Visual Question Answering
+ - local: tasks/text-to-speech
+ title: Text to speech
+ - local: tasks/image_text_to_text
+ title: Image-text-to-text
+ - local: tasks/video_text_to_text
+ title: Video-text-to-text
+ title: Multimodal
+ - isExpanded: false
+ sections:
+ - local: generation_strategies
+ title: Customize the generation strategy
+ - local: kv_cache
+ title: Best Practices for Generation with Cache
+ title: Generation
+ - isExpanded: false
+ sections:
+ - local: tasks/idefics
+ title: Image tasks with IDEFICS
+ - local: tasks/prompting
+ title: LLM prompting guide
+ title: Prompting
+ title: Task Guides
+- sections:
+ - local: fast_tokenizers
+ title: Use fast tokenizers from ๐ค Tokenizers
+ - local: multilingual
+ title: Run inference with multilingual models
+ - local: create_a_model
+ title: Use model-specific APIs
+ - local: custom_models
+ title: Share a custom model
+ - local: chat_templating
+ title: Chat templates
+ - local: trainer
+ title: Trainer
+ - local: sagemaker
+ title: Run training on Amazon SageMaker
+ - local: serialization
+ title: Export to ONNX
+ - local: tflite
+ title: Export to TFLite
+ - local: torchscript
+ title: Export to TorchScript
+ - local: benchmarks
+ title: Benchmarks
+ - local: notebooks
+ title: Notebooks with examples
+ - local: community
+ title: Community resources
+ - local: troubleshooting
+ title: Troubleshoot
+ - local: gguf
+ title: Interoperability with GGUF files
+ - local: tiktoken
+ title: Interoperability with TikToken files
+ - local: modular_transformers
+ title: Modularity in `transformers`
+ - local: how_to_hack_models
+ title: Model Hacking (overwriting a class to your usage)
+ title: Developer guides
+- sections:
+ - local: quantization/overview
+ title: Getting started
+ - local: quantization/bitsandbytes
+ title: bitsandbytes
+ - local: quantization/gptq
+ title: GPTQ
+ - local: quantization/awq
+ title: AWQ
+ - local: quantization/aqlm
+ title: AQLM
+ - local: quantization/quanto
+ title: Quanto
+ - local: quantization/eetq
+ title: EETQ
+ - local: quantization/hqq
+ title: HQQ
+ - local: quantization/fbgemm_fp8
+ title: FBGEMM_FP8
+ - local: quantization/optimum
+ title: Optimum
+ - local: quantization/torchao
+ title: TorchAO
+ - local: quantization/bitnet
+ title: BitNet
+ - local: quantization/compressed_tensors
+ title: compressed-tensors
+ - local: quantization/contribute
+ title: Contribute new quantization method
+ title: Quantization Methods
+- sections:
+ - local: performance
+ title: Overview
+ - local: llm_optims
+ title: LLM inference optimization
+ - sections:
+ - local: perf_train_gpu_one
+ title: Methods and tools for efficient training on a single GPU
+ - local: perf_train_gpu_many
+ title: Multiple GPUs and parallelism
+ - local: fsdp
+ title: Fully Sharded Data Parallel
+ - local: deepspeed
+ title: DeepSpeed
+ - local: perf_train_cpu
+ title: Efficient training on CPU
+ - local: perf_train_cpu_many
+ title: Distributed CPU training
+ - local: perf_train_tpu_tf
+ title: Training on TPU with TensorFlow
+ - local: perf_train_special
+ title: PyTorch training on Apple silicon
+ - local: perf_hardware
+ title: Custom hardware for training
+ - local: hpo_train
+ title: Hyperparameter Search using Trainer API
+ title: Efficient training techniques
+ - sections:
+ - local: perf_infer_cpu
+ title: CPU inference
+ - local: perf_infer_gpu_one
+ title: GPU inference
+ - local: perf_infer_gpu_multi
+ title: Multi-GPU inference
+ title: Optimizing inference
+ - local: big_models
+ title: Instantiate a big model
+ - local: debugging
+ title: Debugging
+ - local: tf_xla
+ title: XLA Integration for TensorFlow Models
+ - local: perf_torch_compile
+ title: Optimize inference using `torch.compile()`
+ title: Performance and scalability
+- sections:
+ - local: contributing
+ title: How to contribute to ๐ค Transformers?
+ - local: add_new_model
+ title: How to add a model to ๐ค Transformers?
+ - local: add_new_pipeline
+ title: How to add a pipeline to ๐ค Transformers?
+ - local: testing
+ title: Testing
+ - local: pr_checks
+ title: Checks on a Pull Request
+ title: Contribute
+- sections:
+ - local: philosophy
+ title: Philosophy
+ - local: glossary
+ title: Glossary
+ - local: task_summary
+ title: What ๐ค Transformers can do
+ - local: tasks_explained
+ title: How ๐ค Transformers solve tasks
+ - local: model_summary
+ title: The Transformer model family
+ - local: tokenizer_summary
+ title: Summary of the tokenizers
+ - local: attention
+ title: Attention mechanisms
+ - local: pad_truncation
+ title: Padding and truncation
+ - local: bertology
+ title: BERTology
+ - local: perplexity
+ title: Perplexity of fixed-length models
+ - local: pipeline_webserver
+ title: Pipelines for webserver inference
+ - local: model_memory_anatomy
+ title: Model training anatomy
+ - local: llm_tutorial_optimization
+ title: Getting the most out of LLMs
+ title: Conceptual guides
+- sections:
+ - sections:
+ - local: main_classes/agent
+ title: Agents and Tools
+ - local: model_doc/auto
+ title: Auto Classes
+ - local: main_classes/backbones
+ title: Backbones
+ - local: main_classes/callback
+ title: Callbacks
+ - local: main_classes/configuration
+ title: Configuration
+ - local: main_classes/data_collator
+ title: Data Collator
+ - local: main_classes/keras_callbacks
+ title: Keras callbacks
+ - local: main_classes/logging
+ title: Logging
+ - local: main_classes/model
+ title: Models
+ - local: main_classes/text_generation
+ title: Text Generation
+ - local: main_classes/onnx
+ title: ONNX
+ - local: main_classes/optimizer_schedules
+ title: Optimization
+ - local: main_classes/output
+ title: Model outputs
+ - local: main_classes/pipelines
+ title: Pipelines
+ - local: main_classes/processors
+ title: Processors
+ - local: main_classes/quantization
+ title: Quantization
+ - local: main_classes/tokenizer
+ title: Tokenizer
+ - local: main_classes/trainer
+ title: Trainer
+ - local: main_classes/deepspeed
+ title: DeepSpeed
+ - local: main_classes/executorch
+ title: ExecuTorch
+ - local: main_classes/feature_extractor
+ title: Feature Extractor
+ - local: main_classes/image_processor
+ title: Image Processor
+ title: Main Classes
+ - sections:
+ - isExpanded: false
+ sections:
+ - local: model_doc/albert
+ title: ALBERT
+ - local: model_doc/bart
+ title: BART
+ - local: model_doc/barthez
+ title: BARThez
+ - local: model_doc/bartpho
+ title: BARTpho
+ - local: model_doc/bert
+ title: BERT
+ - local: model_doc/bert-generation
+ title: BertGeneration
+ - local: model_doc/bert-japanese
+ title: BertJapanese
+ - local: model_doc/bertweet
+ title: Bertweet
+ - local: model_doc/big_bird
+ title: BigBird
+ - local: model_doc/bigbird_pegasus
+ title: BigBirdPegasus
+ - local: model_doc/biogpt
+ title: BioGpt
+ - local: model_doc/blenderbot
+ title: Blenderbot
+ - local: model_doc/blenderbot-small
+ title: Blenderbot Small
+ - local: model_doc/bloom
+ title: BLOOM
+ - local: model_doc/bort
+ title: BORT
+ - local: model_doc/byt5
+ title: ByT5
+ - local: model_doc/camembert
+ title: CamemBERT
+ - local: model_doc/canine
+ title: CANINE
+ - local: model_doc/codegen
+ title: CodeGen
+ - local: model_doc/code_llama
+ title: CodeLlama
+ - local: model_doc/cohere
+ title: Cohere
+ - local: model_doc/convbert
+ title: ConvBERT
+ - local: model_doc/cpm
+ title: CPM
+ - local: model_doc/cpmant
+ title: CPMANT
+ - local: model_doc/ctrl
+ title: CTRL
+ - local: model_doc/dbrx
+ title: DBRX
+ - local: model_doc/deberta
+ title: DeBERTa
+ - local: model_doc/deberta-v2
+ title: DeBERTa-v2
+ - local: model_doc/dialogpt
+ title: DialoGPT
+ - local: model_doc/distilbert
+ title: DistilBERT
+ - local: model_doc/dpr
+ title: DPR
+ - local: model_doc/electra
+ title: ELECTRA
+ - local: model_doc/encoder-decoder
+ title: Encoder Decoder Models
+ - local: model_doc/ernie
+ title: ERNIE
+ - local: model_doc/ernie_m
+ title: ErnieM
+ - local: model_doc/esm
+ title: ESM
+ - local: model_doc/falcon
+ title: Falcon
+ - local: model_doc/falcon_mamba
+ title: FalconMamba
+ - local: model_doc/fastspeech2_conformer
+ title: FastSpeech2Conformer
+ - local: model_doc/flan-t5
+ title: FLAN-T5
+ - local: model_doc/flan-ul2
+ title: FLAN-UL2
+ - local: model_doc/flaubert
+ title: FlauBERT
+ - local: model_doc/fnet
+ title: FNet
+ - local: model_doc/fsmt
+ title: FSMT
+ - local: model_doc/funnel
+ title: Funnel Transformer
+ - local: model_doc/fuyu
+ title: Fuyu
+ - local: model_doc/gemma
+ title: Gemma
+ - local: model_doc/gemma2
+ title: Gemma2
+ - local: model_doc/glm
+ title: GLM
+ - local: model_doc/openai-gpt
+ title: GPT
+ - local: model_doc/gpt_neo
+ title: GPT Neo
+ - local: model_doc/gpt_neox
+ title: GPT NeoX
+ - local: model_doc/gpt_neox_japanese
+ title: GPT NeoX Japanese
+ - local: model_doc/gptj
+ title: GPT-J
+ - local: model_doc/gpt2
+ title: GPT2
+ - local: model_doc/gpt_bigcode
+ title: GPTBigCode
+ - local: model_doc/gptsan-japanese
+ title: GPTSAN Japanese
+ - local: model_doc/gpt-sw3
+ title: GPTSw3
+ - local: model_doc/granite
+ title: Granite
+ - local: model_doc/granitemoe
+ title: GraniteMoe
+ - local: model_doc/herbert
+ title: HerBERT
+ - local: model_doc/ibert
+ title: I-BERT
+ - local: model_doc/jamba
+ title: Jamba
+ - local: model_doc/jetmoe
+ title: JetMoe
+ - local: model_doc/jukebox
+ title: Jukebox
+ - local: model_doc/led
+ title: LED
+ - local: model_doc/llama
+ title: LLaMA
+ - local: model_doc/llama2
+ title: Llama2
+ - local: model_doc/llama3
+ title: Llama3
+ - local: model_doc/longformer
+ title: Longformer
+ - local: model_doc/longt5
+ title: LongT5
+ - local: model_doc/luke
+ title: LUKE
+ - local: model_doc/m2m_100
+ title: M2M100
+ - local: model_doc/madlad-400
+ title: MADLAD-400
+ - local: model_doc/mamba
+ title: Mamba
+ - local: model_doc/mamba2
+ title: mamba2
+ - local: model_doc/marian
+ title: MarianMT
+ - local: model_doc/markuplm
+ title: MarkupLM
+ - local: model_doc/mbart
+ title: MBart and MBart-50
+ - local: model_doc/mega
+ title: MEGA
+ - local: model_doc/megatron-bert
+ title: MegatronBERT
+ - local: model_doc/megatron_gpt2
+ title: MegatronGPT2
+ - local: model_doc/mistral
+ title: Mistral
+ - local: model_doc/mixtral
+ title: Mixtral
+ - local: model_doc/mluke
+ title: mLUKE
+ - local: model_doc/mobilebert
+ title: MobileBERT
+ - local: model_doc/mpnet
+ title: MPNet
+ - local: model_doc/mpt
+ title: MPT
+ - local: model_doc/mra
+ title: MRA
+ - local: model_doc/mt5
+ title: MT5
+ - local: model_doc/mvp
+ title: MVP
+ - local: model_doc/myt5
+ title: myt5
+ - local: model_doc/nemotron
+ title: Nemotron
+ - local: model_doc/nezha
+ title: NEZHA
+ - local: model_doc/nllb
+ title: NLLB
+ - local: model_doc/nllb-moe
+ title: NLLB-MoE
+ - local: model_doc/nystromformer
+ title: Nystrรถmformer
+ - local: model_doc/olmo
+ title: OLMo
+ - local: model_doc/olmo2
+ title: OLMo2
+ - local: model_doc/olmoe
+ title: OLMoE
+ - local: model_doc/open-llama
+ title: Open-Llama
+ - local: model_doc/opt
+ title: OPT
+ - local: model_doc/pegasus
+ title: Pegasus
+ - local: model_doc/pegasus_x
+ title: PEGASUS-X
+ - local: model_doc/persimmon
+ title: Persimmon
+ - local: model_doc/phi
+ title: Phi
+ - local: model_doc/phi3
+ title: Phi-3
+ - local: model_doc/phimoe
+ title: PhiMoE
+ - local: model_doc/phobert
+ title: PhoBERT
+ - local: model_doc/plbart
+ title: PLBart
+ - local: model_doc/prophetnet
+ title: ProphetNet
+ - local: model_doc/qdqbert
+ title: QDQBert
+ - local: model_doc/qwen2
+ title: Qwen2
+ - local: model_doc/qwen2_moe
+ title: Qwen2MoE
+ - local: model_doc/rag
+ title: RAG
+ - local: model_doc/realm
+ title: REALM
+ - local: model_doc/recurrent_gemma
+ title: RecurrentGemma
+ - local: model_doc/reformer
+ title: Reformer
+ - local: model_doc/rembert
+ title: RemBERT
+ - local: model_doc/retribert
+ title: RetriBERT
+ - local: model_doc/roberta
+ title: RoBERTa
+ - local: model_doc/roberta-prelayernorm
+ title: RoBERTa-PreLayerNorm
+ - local: model_doc/roc_bert
+ title: RoCBert
+ - local: model_doc/roformer
+ title: RoFormer
+ - local: model_doc/rwkv
+ title: RWKV
+ - local: model_doc/splinter
+ title: Splinter
+ - local: model_doc/squeezebert
+ title: SqueezeBERT
+ - local: model_doc/stablelm
+ title: StableLm
+ - local: model_doc/starcoder2
+ title: Starcoder2
+ - local: model_doc/switch_transformers
+ title: SwitchTransformers
+ - local: model_doc/t5
+ title: T5
+ - local: model_doc/t5v1.1
+ title: T5v1.1
+ - local: model_doc/tapex
+ title: TAPEX
+ - local: model_doc/transfo-xl
+ title: Transformer XL
+ - local: model_doc/ul2
+ title: UL2
+ - local: model_doc/umt5
+ title: UMT5
+ - local: model_doc/xmod
+ title: X-MOD
+ - local: model_doc/xglm
+ title: XGLM
+ - local: model_doc/xlm
+ title: XLM
+ - local: model_doc/xlm-prophetnet
+ title: XLM-ProphetNet
+ - local: model_doc/xlm-roberta
+ title: XLM-RoBERTa
+ - local: model_doc/xlm-roberta-xl
+ title: XLM-RoBERTa-XL
+ - local: model_doc/xlm-v
+ title: XLM-V
+ - local: model_doc/xlnet
+ title: XLNet
+ - local: model_doc/yoso
+ title: YOSO
+ - local: model_doc/zamba
+ title: Zamba
+ title: Text models
+ - isExpanded: false
+ sections:
+ - local: model_doc/beit
+ title: BEiT
+ - local: model_doc/bit
+ title: BiT
+ - local: model_doc/conditional_detr
+ title: Conditional DETR
+ - local: model_doc/convnext
+ title: ConvNeXT
+ - local: model_doc/convnextv2
+ title: ConvNeXTV2
+ - local: model_doc/cvt
+ title: CvT
+ - local: model_doc/deformable_detr
+ title: Deformable DETR
+ - local: model_doc/deit
+ title: DeiT
+ - local: model_doc/depth_anything
+ title: Depth Anything
+ - local: model_doc/depth_anything_v2
+ title: Depth Anything V2
+ - local: model_doc/deta
+ title: DETA
+ - local: model_doc/detr
+ title: DETR
+ - local: model_doc/dinat
+ title: DiNAT
+ - local: model_doc/dinov2
+ title: DINOV2
+ - local: model_doc/dit
+ title: DiT
+ - local: model_doc/dpt
+ title: DPT
+ - local: model_doc/efficientformer
+ title: EfficientFormer
+ - local: model_doc/efficientnet
+ title: EfficientNet
+ - local: model_doc/focalnet
+ title: FocalNet
+ - local: model_doc/glpn
+ title: GLPN
+ - local: model_doc/hiera
+ title: Hiera
+ - local: model_doc/ijepa
+ title: I-JEPA
+ - local: model_doc/imagegpt
+ title: ImageGPT
+ - local: model_doc/levit
+ title: LeViT
+ - local: model_doc/mask2former
+ title: Mask2Former
+ - local: model_doc/maskformer
+ title: MaskFormer
+ - local: model_doc/mobilenet_v1
+ title: MobileNetV1
+ - local: model_doc/mobilenet_v2
+ title: MobileNetV2
+ - local: model_doc/mobilevit
+ title: MobileViT
+ - local: model_doc/mobilevitv2
+ title: MobileViTV2
+ - local: model_doc/nat
+ title: NAT
+ - local: model_doc/poolformer
+ title: PoolFormer
+ - local: model_doc/pvt
+ title: Pyramid Vision Transformer (PVT)
+ - local: model_doc/pvt_v2
+ title: Pyramid Vision Transformer v2 (PVTv2)
+ - local: model_doc/regnet
+ title: RegNet
+ - local: model_doc/resnet
+ title: ResNet
+ - local: model_doc/rt_detr
+ title: RT-DETR
+ - local: model_doc/segformer
+ title: SegFormer
+ - local: model_doc/seggpt
+ title: SegGpt
+ - local: model_doc/superpoint
+ title: SuperPoint
+ - local: model_doc/swiftformer
+ title: SwiftFormer
+ - local: model_doc/swin
+ title: Swin Transformer
+ - local: model_doc/swinv2
+ title: Swin Transformer V2
+ - local: model_doc/swin2sr
+ title: Swin2SR
+ - local: model_doc/table-transformer
+ title: Table Transformer
+ - local: model_doc/upernet
+ title: UperNet
+ - local: model_doc/van
+ title: VAN
+ - local: model_doc/vit
+ title: Vision Transformer (ViT)
+ - local: model_doc/vit_hybrid
+ title: ViT Hybrid
+ - local: model_doc/vitdet
+ title: ViTDet
+ - local: model_doc/vit_mae
+ title: ViTMAE
+ - local: model_doc/vitmatte
+ title: ViTMatte
+ - local: model_doc/vit_msn
+ title: ViTMSN
+ - local: model_doc/yolos
+ title: YOLOS
+ - local: model_doc/zoedepth
+ title: ZoeDepth
+ title: Vision models
+ - isExpanded: false
+ sections:
+ - local: model_doc/audio-spectrogram-transformer
+ title: Audio Spectrogram Transformer
+ - local: model_doc/bark
+ title: Bark
+ - local: model_doc/clap
+ title: CLAP
+ - local: model_doc/dac
+ title: dac
+ - local: model_doc/encodec
+ title: EnCodec
+ - local: model_doc/hiera
+ title: Hiera
+ - local: model_doc/hubert
+ title: Hubert
+ - local: model_doc/mctct
+ title: MCTCT
+ - local: model_doc/mimi
+ title: Mimi
+ - local: model_doc/mms
+ title: MMS
+ - local: model_doc/moshi
+ title: Moshi
+ - local: model_doc/musicgen
+ title: MusicGen
+ - local: model_doc/musicgen_melody
+ title: MusicGen Melody
+ - local: model_doc/pop2piano
+ title: Pop2Piano
+ - local: model_doc/seamless_m4t
+ title: Seamless-M4T
+ - local: model_doc/seamless_m4t_v2
+ title: SeamlessM4T-v2
+ - local: model_doc/sew
+ title: SEW
+ - local: model_doc/sew-d
+ title: SEW-D
+ - local: model_doc/speech_to_text
+ title: Speech2Text
+ - local: model_doc/speech_to_text_2
+ title: Speech2Text2
+ - local: model_doc/speecht5
+ title: SpeechT5
+ - local: model_doc/unispeech
+ title: UniSpeech
+ - local: model_doc/unispeech-sat
+ title: UniSpeech-SAT
+ - local: model_doc/univnet
+ title: UnivNet
+ - local: model_doc/vits
+ title: VITS
+ - local: model_doc/wav2vec2
+ title: Wav2Vec2
+ - local: model_doc/wav2vec2-bert
+ title: Wav2Vec2-BERT
+ - local: model_doc/wav2vec2-conformer
+ title: Wav2Vec2-Conformer
+ - local: model_doc/wav2vec2_phoneme
+ title: Wav2Vec2Phoneme
+ - local: model_doc/wavlm
+ title: WavLM
+ - local: model_doc/whisper
+ title: Whisper
+ - local: model_doc/xls_r
+ title: XLS-R
+ - local: model_doc/xlsr_wav2vec2
+ title: XLSR-Wav2Vec2
+ title: Audio models
+ - isExpanded: false
+ sections:
+ - local: model_doc/timesformer
+ title: TimeSformer
+ - local: model_doc/videomae
+ title: VideoMAE
+ - local: model_doc/vivit
+ title: ViViT
+ title: Video models
+ - isExpanded: false
+ sections:
+ - local: model_doc/align
+ title: ALIGN
+ - local: model_doc/altclip
+ title: AltCLIP
+ - local: model_doc/aria
+ title: Aria
+ - local: model_doc/blip
+ title: BLIP
+ - local: model_doc/blip-2
+ title: BLIP-2
+ - local: model_doc/bridgetower
+ title: BridgeTower
+ - local: model_doc/bros
+ title: BROS
+ - local: model_doc/chameleon
+ title: Chameleon
+ - local: model_doc/chinese_clip
+ title: Chinese-CLIP
+ - local: model_doc/clip
+ title: CLIP
+ - local: model_doc/clipseg
+ title: CLIPSeg
+ - local: model_doc/clvp
+ title: CLVP
+ - local: model_doc/data2vec
+ title: Data2Vec
+ - local: model_doc/deplot
+ title: DePlot
+ - local: model_doc/donut
+ title: Donut
+ - local: model_doc/flava
+ title: FLAVA
+ - local: model_doc/git
+ title: GIT
+ - local: model_doc/grounding-dino
+ title: Grounding DINO
+ - local: model_doc/groupvit
+ title: GroupViT
+ - local: model_doc/idefics
+ title: IDEFICS
+ - local: model_doc/idefics2
+ title: Idefics2
+ - local: model_doc/idefics3
+ title: Idefics3
+ - local: model_doc/instructblip
+ title: InstructBLIP
+ - local: model_doc/instructblipvideo
+ title: InstructBlipVideo
+ - local: model_doc/kosmos-2
+ title: KOSMOS-2
+ - local: model_doc/layoutlm
+ title: LayoutLM
+ - local: model_doc/layoutlmv2
+ title: LayoutLMV2
+ - local: model_doc/layoutlmv3
+ title: LayoutLMV3
+ - local: model_doc/layoutxlm
+ title: LayoutXLM
+ - local: model_doc/lilt
+ title: LiLT
+ - local: model_doc/llava
+ title: Llava
+ - local: model_doc/llava_next
+ title: LLaVA-NeXT
+ - local: model_doc/llava_next_video
+ title: LLaVa-NeXT-Video
+ - local: model_doc/llava_onevision
+ title: LLaVA-Onevision
+ - local: model_doc/lxmert
+ title: LXMERT
+ - local: model_doc/matcha
+ title: MatCha
+ - local: model_doc/mgp-str
+ title: MGP-STR
+ - local: model_doc/mllama
+ title: mllama
+ - local: model_doc/nougat
+ title: Nougat
+ - local: model_doc/omdet-turbo
+ title: OmDet-Turbo
+ - local: model_doc/oneformer
+ title: OneFormer
+ - local: model_doc/owlvit
+ title: OWL-ViT
+ - local: model_doc/owlv2
+ title: OWLv2
+ - local: model_doc/paligemma
+ title: PaliGemma
+ - local: model_doc/perceiver
+ title: Perceiver
+ - local: model_doc/pix2struct
+ title: Pix2Struct
+ - local: model_doc/pixtral
+ title: Pixtral
+ - local: model_doc/qwen2_audio
+ title: Qwen2Audio
+ - local: model_doc/qwen2_vl
+ title: Qwen2VL
+ - local: model_doc/sam
+ title: Segment Anything
+ - local: model_doc/siglip
+ title: SigLIP
+ - local: model_doc/speech-encoder-decoder
+ title: Speech Encoder Decoder Models
+ - local: model_doc/tapas
+ title: TAPAS
+ - local: model_doc/trocr
+ title: TrOCR
+ - local: model_doc/tvlt
+ title: TVLT
+ - local: model_doc/tvp
+ title: TVP
+ - local: model_doc/udop
+ title: UDOP
+ - local: model_doc/video_llava
+ title: VideoLlava
+ - local: model_doc/vilt
+ title: ViLT
+ - local: model_doc/vipllava
+ title: VipLlava
+ - local: model_doc/vision-encoder-decoder
+ title: Vision Encoder Decoder Models
+ - local: model_doc/vision-text-dual-encoder
+ title: Vision Text Dual Encoder
+ - local: model_doc/visual_bert
+ title: VisualBERT
+ - local: model_doc/xclip
+ title: X-CLIP
+ title: Multimodal models
+ - isExpanded: false
+ sections:
+ - local: model_doc/decision_transformer
+ title: Decision Transformer
+ - local: model_doc/trajectory_transformer
+ title: Trajectory Transformer
+ title: Reinforcement learning models
+ - isExpanded: false
+ sections:
+ - local: model_doc/autoformer
+ title: Autoformer
+ - local: model_doc/informer
+ title: Informer
+ - local: model_doc/patchtsmixer
+ title: PatchTSMixer
+ - local: model_doc/patchtst
+ title: PatchTST
+ - local: model_doc/time_series_transformer
+ title: Time Series Transformer
+ title: Time series models
+ - isExpanded: false
+ sections:
+ - local: model_doc/graphormer
+ title: Graphormer
+ title: Graph models
+ title: Models
+ - sections:
+ - local: internal/modeling_utils
+ title: Custom Layers and Utilities
+ - local: internal/pipelines_utils
+ title: Utilities for pipelines
+ - local: internal/tokenization_utils
+ title: Utilities for Tokenizers
+ - local: internal/trainer_utils
+ title: Utilities for Trainer
+ - local: internal/generation_utils
+ title: Utilities for Generation
+ - local: internal/image_processing_utils
+ title: Utilities for Image Processors
+ - local: internal/audio_utils
+ title: Utilities for Audio processing
+ - local: internal/file_utils
+ title: General Utilities
+ - local: internal/time_series_utils
+ title: Utilities for Time Series
+ title: Internal Helpers
+ title: API
diff --git a/docs/source/agents.md b/docs/source/agents.md
new file mode 100644
index 0000000..56c9184
--- /dev/null
+++ b/docs/source/agents.md
@@ -0,0 +1,431 @@
+
+# Agents and tools
+
+[[open-in-colab]]
+
+### What is an agent?
+
+Large Language Models (LLMs) trained to perform [causal language modeling](./tasks/language_modeling) can tackle a wide range of tasks, but they often struggle with basic tasks like logic, calculation, and search. When prompted in domains in which they do not perform well, they often fail to generate the answer we expect them to.
+
+One approach to overcome this weakness is to create an *agent*.
+
+An agent is a system that uses an LLM as its engine, and it has access to functions called *tools*.
+
+These *tools* are functions for performing a task, and they contain all necessary description for the agent to properly use them.
+
+The agent can be programmed to:
+- devise a series of actions/tools and run them all at once, like the [`CodeAgent`]
+- plan and execute actions/tools one by one and wait for the outcome of each action before launching the next one, like the [`ReactJsonAgent`]
+
+### Types of agents
+
+#### Code agent
+
+This agent has a planning step, then generates python code to execute all its actions at once. It natively handles different input and output types for its tools, thus it is the recommended choice for multimodal tasks.
+
+#### React agents
+
+This is the go-to agent to solve reasoning tasks, since the ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) makes it really efficient to think on the basis of its previous observations.
+
+We implement two versions of ReactJsonAgent:
+- [`ReactJsonAgent`] generates tool calls as a JSON in its output.
+- [`ReactCodeAgent`] is a new type of ReactJsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
+
+> [!TIP]
+> Read [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about ReAct agents.
+
+
+
+
+
+
+
+
+For example, here is how a ReAct Code agent would work its way through the following question.
+
+```py3
+>>> agent.run(
+... "How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?",
+... )
+=====New task=====
+How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?
+====Agent is executing the code below:
+bert_blocks = search(query="number of blocks in BERT base encoder")
+print("BERT blocks:", bert_blocks)
+====
+Print outputs:
+BERT blocks: twelve encoder blocks
+
+====Agent is executing the code below:
+attention_layer = search(query="number of layers in Attention is All You Need")
+print("Attention layers:", attention_layer)
+====
+Print outputs:
+Attention layers: Encoder: The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position- 2 Page 3 Figure 1: The Transformer - model architecture.
+
+====Agent is executing the code below:
+bert_blocks = 12
+attention_layers = 6
+diff = bert_blocks - attention_layers
+print("Difference in blocks:", diff)
+final_answer(diff)
+====
+
+Print outputs:
+Difference in blocks: 6
+
+Final answer: 6
+```
+
+### How can I build an agent?
+
+To initialize an agent, you need these arguments:
+
+- an LLM to power your agent - the agent is not exactly the LLM, itโs more like the agent is a program that uses an LLM as its engine.
+- a system prompt: what the LLM engine will be prompted with to generate its output
+- a toolbox from which the agent pick tools to execute
+- a parser to extract from the LLM output which tools are to call and with which arguments
+
+Upon initialization of the agent system, the tool attributes are used to generate a tool description, then baked into the agentโs `system_prompt` to let it know which tools it can use and why.
+
+To start with, please install the `agents` extras in order to install all default dependencies.
+
+```bash
+pip install transformers[agents]
+```
+
+Build your LLM engine by defining a `llm_engine` method which accepts a list of [messages](./chat_templating) and returns text. This callable also needs to accept a `stop` argument that indicates when to stop generating.
+
+```python
+from huggingface_hub import login, InferenceClient
+
+login("")
+
+client = InferenceClient(model="meta-llama/Meta-Llama-3-70B-Instruct")
+
+def llm_engine(messages, stop_sequences=["Task"]) -> str:
+ response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1000)
+ answer = response.choices[0].message.content
+ return answer
+```
+
+You could use any `llm_engine` method as long as:
+1. it follows the [messages format](./chat_templating) (`List[Dict[str, str]]`) for its input `messages`, and it returns a `str`.
+2. it stops generating outputs at the sequences passed in the argument `stop_sequences`
+
+Additionally, `llm_engine` can also take a `grammar` argument. In the case where you specify a `grammar` upon agent initialization, this argument will be passed to the calls to llm_engine, with the `grammar` that you defined upon initialization, to allow [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) in order to force properly-formatted agent outputs.
+
+You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
+
+Now you can create an agent, like [`CodeAgent`], and run it. You can also create a [`TransformersEngine`] with a pre-initialized pipeline to run inference on your local machine using `transformers`.
+For convenience, since agentic behaviours generally require stronger models such as `Llama-3.1-70B-Instruct` that are harder to run locally for now, we also provide the [`HfApiEngine`] class that initializes a `huggingface_hub.InferenceClient` under the hood.
+
+```python
+from transformers import CodeAgent, HfApiEngine
+
+llm_engine = HfApiEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
+agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
+
+agent.run(
+ "Could you translate this sentence from French, say it out loud and return the audio.",
+ sentence="Oรน est la boulangerie la plus proche?",
+)
+```
+
+This will be handy in case of emergency baguette need!
+You can even leave the argument `llm_engine` undefined, and an [`HfApiEngine`] will be created by default.
+
+```python
+from transformers import CodeAgent
+
+agent = CodeAgent(tools=[], add_base_tools=True)
+
+agent.run(
+ "Could you translate this sentence from French, say it out loud and give me the audio.",
+ sentence="Oรน est la boulangerie la plus proche?",
+)
+```
+
+Note that we used an additional `sentence` argument: you can pass text as additional arguments to the model.
+
+You can also use this to indicate the path to local or remote files for the model to use:
+
+```py
+from transformers import ReactCodeAgent
+
+agent = ReactCodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
+
+agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
+```
+
+
+The prompt and output parser were automatically defined, but you can easily inspect them by calling the `system_prompt_template` on your agent.
+
+```python
+print(agent.system_prompt_template)
+```
+
+It's important to explain as clearly as possible the task you want to perform.
+Every [`~Agent.run`] operation is independent, and since an agent is powered by an LLM, minor variations in your prompt might yield completely different results.
+You can also run an agent consecutively for different tasks: each time the attributes `agent.task` and `agent.logs` will be re-initialized.
+
+
+#### Code execution
+
+A Python interpreter executes the code on a set of inputs passed along with your tools.
+This should be safe because the only functions that can be called are the tools you provided (especially if it's only tools by Hugging Face) and the print function, so you're already limited in what can be executed.
+
+The Python interpreter also doesn't allow imports by default outside of a safe list, so all the most obvious attacks shouldn't be an issue.
+You can still authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`ReactCodeAgent`] or [`CodeAgent`]:
+
+```py
+>>> from transformers import ReactCodeAgent
+
+>>> agent = ReactCodeAgent(tools=[], additional_authorized_imports=['requests', 'bs4'])
+>>> agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
+
+(...)
+'Hugging Face โ Blog'
+```
+
+The execution will stop at any code trying to perform an illegal operation or if there is a regular Python error with the code generated by the agent.
+
+> [!WARNING]
+> The LLM can generate arbitrary code that will then be executed: do not add any unsafe imports!
+
+### The system prompt
+
+An agent, or rather the LLM that drives the agent, generates an output based on the system prompt. The system prompt can be customized and tailored to the intended task. For example, check the system prompt for the [`ReactCodeAgent`] (below version is slightly simplified).
+
+```text
+You will be given a task to solve as best you can.
+You have access to the following tools:
+<>
+
+To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
+
+At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task, then the tools that you want to use.
+Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '/End code' sequence.
+During each intermediate step, you can use 'print()' to save whatever important information you will then need.
+These print outputs will then be available in the 'Observation:' field, for using this information as input for the next step.
+
+In the end you have to return a final answer using the `final_answer` tool.
+
+Here are a few examples using notional tools:
+---
+{examples}
+
+Above example were using notional tools that might not exist for you. You only have acces to those tools:
+<>
+You also can perform computations in the python code you generate.
+
+Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```' sequence. You MUST provide at least the 'Code:' sequence to move forward.
+
+Remember to not perform too many operations in a single code block! You should split the task into intermediate code blocks.
+Print results at the end of each step to save the intermediate results. Then use final_answer() to return the final result.
+
+Remember to make sure that variables you use are all defined.
+
+Now Begin!
+```
+
+The system prompt includes:
+- An *introduction* that explains how the agent should behave and what tools are.
+- A description of all the tools that is defined by a `<>` token that is dynamically replaced at runtime with the tools defined/chosen by the user.
+ - The tool description comes from the tool attributes, `name`, `description`, `inputs` and `output_type`, and a simple `jinja2` template that you can refine.
+- The expected output format.
+
+You could improve the system prompt, for example, by adding an explanation of the output format.
+
+For maximum flexibility, you can overwrite the whole system prompt template by passing your custom prompt as an argument to the `system_prompt` parameter.
+
+```python
+from transformers import ReactJsonAgent
+from transformers.agents import PythonInterpreterTool
+
+agent = ReactJsonAgent(tools=[PythonInterpreterTool()], system_prompt="{your_custom_prompt}")
+```
+
+> [!WARNING]
+> Please make sure to define the `<>` string somewhere in the `template` so the agent is aware
+of the available tools.
+
+
+### Inspecting an agent run
+
+Here are a few useful attributes to inspect what happened after a run:
+- `agent.logs` stores the fine-grained logs of the agent. At every step of the agent's run, everything gets stored in a dictionary that then is appended to `agent.logs`.
+- Running `agent.write_inner_memory_from_logs()` creates an inner memory of the agent's logs for the LLM to view, as a list of chat messages. This method goes over each step of the log and only stores what it's interested in as a message: for instance, it will save the system prompt and task in separate messages, then for each step it will store the LLM output as a message, and the tool call output as another message. Use this if you want a higher-level view of what has happened - but not every log will be transcripted by this method.
+
+## Tools
+
+A tool is an atomic function to be used by an agent.
+
+You can for instance check the [`PythonInterpreterTool`]: it has a name, a description, input descriptions, an output type, and a `__call__` method to perform the action.
+
+When the agent is initialized, the tool attributes are used to generate a tool description which is baked into the agent's system prompt. This lets the agent know which tools it can use and why.
+
+### Default toolbox
+
+Transformers comes with a default toolbox for empowering agents, that you can add to your agent upon initialization with argument `add_base_tools = True`:
+
+- **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document ([Donut](./model_doc/donut))
+- **Image question answering**: given an image, answer a question on this image ([VILT](./model_doc/vilt))
+- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
+- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
+- **Translation**: translates a given sentence from source language to target language.
+- **DuckDuckGo search***: performs a web search using DuckDuckGo browser.
+- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
+
+
+You can manually use a tool by calling the [`load_tool`] function and a task to perform.
+
+
+```python
+from transformers import load_tool
+
+tool = load_tool("text-to-speech")
+audio = tool("This is a text to speech tool")
+```
+
+
+### Create a new tool
+
+You can create your own tool for use cases not covered by the default tools from Hugging Face.
+For example, let's create a tool that returns the most downloaded model for a given task from the Hub.
+
+You'll start with the code below.
+
+```python
+from huggingface_hub import list_models
+
+task = "text-classification"
+
+model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
+print(model.id)
+```
+
+This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator:
+
+
+```py
+from transformers import tool
+
+@tool
+def model_download_tool(task: str) -> str:
+ """
+ This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
+ It returns the name of the checkpoint.
+
+ Args:
+ task: The task for which
+ """
+ model = next(iter(list_models(filter="text-classification", sort="downloads", direction=-1)))
+ return model.id
+```
+
+The function needs:
+- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's put `model_download_tool`.
+- Type hints on both inputs and output
+- A description, that includes an 'Args:' part where each argument is described (without a type indication this time, it will be pulled from the type hint).
+All these will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible!
+
+> [!TIP]
+> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).
+
+Then you can directly initialize your agent:
+```py
+from transformers import CodeAgent
+agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
+agent.run(
+ "Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
+)
+```
+
+You get the following:
+```text
+======== New task ========
+Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?
+==== Agent is executing the code below:
+most_downloaded_model = model_download_tool(task="text-to-video")
+print(f"The most downloaded model for the 'text-to-video' task is {most_downloaded_model}.")
+====
+```
+
+And the output:
+`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`
+
+### Manage your agent's toolbox
+
+If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool.
+
+Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
+
+```python
+from transformers import CodeAgent
+
+agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
+agent.toolbox.add_tool(model_download_tool)
+```
+Now we can leverage both the new tool and the previous text-to-speech tool:
+
+```python
+agent.run(
+ "Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub and return the audio?"
+)
+```
+
+
+| **Audio** |
+|------------------------------------------------------------------------------------------------------------------------------------------------------|
+| |
+
+
+> [!WARNING]
+> Beware when adding tools to an agent that already works well because it can bias selection towards your tool or select another tool other than the one already defined.
+
+
+Use the `agent.toolbox.update_tool()` method to replace an existing tool in the agent's toolbox.
+This is useful if your new tool is a one-to-one replacement of the existing tool because the agent already knows how to perform that specific task.
+Just make sure the new tool follows the same API as the replaced tool or adapt the system prompt template to ensure all examples using the replaced tool are updated.
+
+
+### Use a collection of tools
+
+You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use.
+Then pass them as a list to initialize you agent, and start using them!
+
+```py
+from transformers import ToolCollection, ReactCodeAgent
+
+image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
+agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
+
+agent.run("Please draw me a picture of rivers and lakes.")
+```
+
+To speed up the start, tools are loaded only if called by the agent.
+
+This gets you this image:
+
+
diff --git a/docs/source/agents_advanced.md b/docs/source/agents_advanced.md
new file mode 100644
index 0000000..c4753bf
--- /dev/null
+++ b/docs/source/agents_advanced.md
@@ -0,0 +1,261 @@
+
+# Agents, supercharged - Multi-agents, External tools, and more
+
+[[open-in-colab]]
+
+### What is an agent?
+
+> [!TIP]
+> If you're new to `transformers.agents`, make sure to first read the main [agents documentation](./agents).
+
+In this page we're going to highlight several advanced uses of `transformers.agents`.
+
+## Multi-agents
+
+Multi-agent has been introduced in Microsoft's framework [Autogen](https://huggingface.co/papers/2308.08155).
+It simply means having several agents working together to solve your task instead of only one.
+It empirically yields better performance on most benchmarks. The reason for this better performance is conceptually simple: for many tasks, rather than using a do-it-all system, you would prefer to specialize units on sub-tasks. Here, having agents with separate tool sets and memories allows to achieve efficient specialization.
+
+You can easily build hierarchical multi-agent systems with `transformers.agents`.
+
+To do so, encapsulate the agent in a [`ManagedAgent`] object. This object needs arguments `agent`, `name`, and a `description`, which will then be embedded in the manager agent's system prompt to let it know how to call this managed agent, as we also do for tools.
+
+Here's an example of making an agent that managed a specific web search agent using our [`DuckDuckGoSearchTool`]:
+
+```py
+from transformers.agents import ReactCodeAgent, HfApiEngine, DuckDuckGoSearchTool, ManagedAgent
+
+llm_engine = HfApiEngine()
+
+web_agent = ReactCodeAgent(tools=[DuckDuckGoSearchTool()], llm_engine=llm_engine)
+
+managed_web_agent = ManagedAgent(
+ agent=web_agent,
+ name="web_search",
+ description="Runs web searches for you. Give it your query as an argument."
+)
+
+manager_agent = ReactCodeAgent(
+ tools=[], llm_engine=llm_engine, managed_agents=[managed_web_agent]
+)
+
+manager_agent.run("Who is the CEO of Hugging Face?")
+```
+
+> [!TIP]
+> For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia).
+
+
+## Advanced tool usage
+
+### Directly define a tool by subclassing Tool, and share it to the Hub
+
+Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.
+
+If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
+
+The custom tool needs:
+- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
+- An attribute `description` is used to populate the agent's system prompt.
+- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
+- An `output_type` attribute, which specifies the output type.
+- A `forward` method which contains the inference code to be executed.
+
+The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema).
+
+```python
+from transformers import Tool
+from huggingface_hub import list_models
+
+class HFModelDownloadsTool(Tool):
+ name = "model_download_counter"
+ description = """
+ This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
+ It returns the name of the checkpoint."""
+
+ inputs = {
+ "task": {
+ "type": "string",
+ "description": "the task category (such as text-classification, depth-estimation, etc)",
+ }
+ }
+ output_type = "string"
+
+ def forward(self, task: str):
+ model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
+ return model.id
+```
+
+Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
+
+
+```python
+from model_downloads import HFModelDownloadsTool
+
+tool = HFModelDownloadsTool()
+```
+
+You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
+
+```python
+tool.push_to_hub("{your_username}/hf-model-downloads")
+```
+
+Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
+
+```python
+from transformers import load_tool, CodeAgent
+
+model_download_tool = load_tool("m-ric/hf-model-downloads")
+```
+
+### Import a Space as a tool ๐
+
+You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method!
+
+You only need to provide the id of the Space on the Hub, its name, and a description that will help you agent understand what the tool does. Under the hood, this will use [`gradio-client`](https://pypi.org/project/gradio-client/) library to call the Space.
+
+For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image.
+
+```
+from transformers import Tool
+
+image_generation_tool = Tool.from_space(
+ "black-forest-labs/FLUX.1-dev",
+ name="image_generator",
+ description="Generate an image from a prompt")
+
+image_generation_tool("A sunny beach")
+```
+And voilร , here's your image! ๐๏ธ
+
+
+
+Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.
+
+```python
+from transformers import ReactCodeAgent
+
+agent = ReactCodeAgent(tools=[image_generation_tool])
+
+agent.run(
+ "Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
+)
+```
+
+```text
+=== Agent thoughts:
+improved_prompt could be "A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background"
+
+Now that I have improved the prompt, I can use the image generator tool to generate an image based on this prompt.
+>>> Agent is executing the code below:
+image = image_generator(prompt="A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background")
+final_answer(image)
+```
+
+
+
+How cool is this? ๐คฉ
+
+### Use gradio-tools
+
+[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
+Face Spaces as tools. It supports many existing Spaces as well as custom Spaces.
+
+Transformers supports `gradio_tools` with the [`Tool.from_gradio`] method. For example, let's use the [`StableDiffusionPromptGeneratorTool`](https://github.com/freddyaboulton/gradio-tools/blob/main/gradio_tools/tools/prompt_generator.py) from `gradio-tools` toolkit for improving prompts to generate better images.
+
+Import and instantiate the tool, then pass it to the `Tool.from_gradio` method:
+
+```python
+from gradio_tools import StableDiffusionPromptGeneratorTool
+from transformers import Tool, load_tool, CodeAgent
+
+gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
+prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
+```
+
+> [!WARNING]
+> gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible.
+
+### Use LangChain tools
+
+We love Langchain and think it has a very compelling suite of tools.
+To import a tool from LangChain, use the `from_langchain()` method.
+
+Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
+This tool will need `pip install google-search-results` to work properly.
+```python
+from langchain.agents import load_tools
+from transformers import Tool, ReactCodeAgent
+
+search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
+
+agent = ReactCodeAgent(tools=[search_tool])
+
+agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?")
+```
+
+## Display your agent run in a cool Gradio interface
+
+You can leverage `gradio.Chatbot` to display your agent's thoughts using `stream_to_gradio`, here is an example:
+
+```py
+import gradio as gr
+from transformers import (
+ load_tool,
+ ReactCodeAgent,
+ HfApiEngine,
+ stream_to_gradio,
+)
+
+# Import tool from Hub
+image_generation_tool = load_tool("m-ric/text-to-image")
+
+llm_engine = HfApiEngine("meta-llama/Meta-Llama-3-70B-Instruct")
+
+# Initialize the agent with the image generation tool
+agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
+
+
+def interact_with_agent(task):
+ messages = []
+ messages.append(gr.ChatMessage(role="user", content=task))
+ yield messages
+ for msg in stream_to_gradio(agent, task):
+ messages.append(msg)
+ yield messages + [
+ gr.ChatMessage(role="assistant", content="โณ Task not finished yet!")
+ ]
+ yield messages
+
+
+with gr.Blocks() as demo:
+ text_input = gr.Textbox(lines=1, label="Chat Message", value="Make me a picture of the Statue of Liberty.")
+ submit = gr.Button("Run illustrator agent!")
+ chatbot = gr.Chatbot(
+ label="Agent",
+ type="messages",
+ avatar_images=(
+ None,
+ "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
+ ),
+ )
+ submit.click(interact_with_agent, [text_input], [chatbot])
+
+if __name__ == "__main__":
+ demo.launch()
+```
diff --git a/docs/source/index.md b/docs/source/index.md
new file mode 100644
index 0000000..28db7a0
--- /dev/null
+++ b/docs/source/index.md
@@ -0,0 +1,74 @@
+
+
+# Accelerate
+
+Accelerate is a library that enables the same PyTorch code to be run across any distributed configuration by adding just four lines of code! In short, training and inference at scale made simple, efficient and adaptable.
+
+```diff
++ from accelerate import Accelerator
++ accelerator = Accelerator()
+
++ model, optimizer, training_dataloader, scheduler = accelerator.prepare(
++ model, optimizer, training_dataloader, scheduler
++ )
+
+ for batch in training_dataloader:
+ optimizer.zero_grad()
+ inputs, targets = batch
+ inputs = inputs.to(device)
+ targets = targets.to(device)
+ outputs = model(inputs)
+ loss = loss_function(outputs, targets)
++ accelerator.backward(loss)
+ optimizer.step()
+ scheduler.step()
+```
+
+Built on `torch_xla` and `torch.distributed`, Accelerate takes care of the heavy lifting, so you don't have to write any custom code to adapt to these platforms.
+Convert existing codebases to utilize [DeepSpeed](usage_guides/deepspeed), perform [fully sharded data parallelism](usage_guides/fsdp), and have automatic support for mixed-precision training!
+
+
+
+ To get a better idea of this process, make sure to check out the [Tutorials](basic_tutorials/overview)!
+
+
+
+
+This code can then be launched on any system through Accelerate's CLI interface:
+```bash
+accelerate launch {my_script.py}
+```
+
+
diff --git a/docs/source/main_classes/agent.md b/docs/source/main_classes/agent.md
new file mode 100644
index 0000000..ed0486b
--- /dev/null
+++ b/docs/source/main_classes/agent.md
@@ -0,0 +1,167 @@
+
+
+# Agents & Tools
+
+
+
+Transformers Agents is an experimental API which is subject to change at any time. Results returned by the agents
+can vary as the APIs or underlying models are prone to change.
+
+
+
+To learn more about agents and tools make sure to read the [introductory guide](../transformers_agents). This page
+contains the API docs for the underlying classes.
+
+## Agents
+
+We provide two types of agents, based on the main [`Agent`] class:
+- [`CodeAgent`] acts in one shot, generating code to solve the task, then executes it at once.
+- [`ReactAgent`] acts step by step, each step consisting of one thought, then one tool call and execution. It has two classes:
+ - [`ReactJsonAgent`] writes its tool calls in JSON.
+ - [`ReactCodeAgent`] writes its tool calls in Python code.
+
+### Agent
+
+[[autodoc]] Agent
+
+### CodeAgent
+
+[[autodoc]] CodeAgent
+
+### React agents
+
+[[autodoc]] ReactAgent
+
+[[autodoc]] ReactJsonAgent
+
+[[autodoc]] ReactCodeAgent
+
+### ManagedAgent
+
+[[autodoc]] ManagedAgent
+
+## Tools
+
+### load_tool
+
+[[autodoc]] load_tool
+
+### tool
+
+[[autodoc]] tool
+
+### Tool
+
+[[autodoc]] Tool
+
+### Toolbox
+
+[[autodoc]] Toolbox
+
+### PipelineTool
+
+[[autodoc]] PipelineTool
+
+### launch_gradio_demo
+
+[[autodoc]] launch_gradio_demo
+
+### stream_to_gradio
+
+[[autodoc]] stream_to_gradio
+
+### ToolCollection
+
+[[autodoc]] ToolCollection
+
+## Engines
+
+You're free to create and use your own engines to be usable by the Agents framework.
+These engines have the following specification:
+1. Follow the [messages format](../chat_templating.md) for its input (`List[Dict[str, str]]`) and return a string.
+2. Stop generating outputs *before* the sequences passed in the argument `stop_sequences`
+
+### TransformersEngine
+
+For convenience, we have added a `TransformersEngine` that implements the points above, taking a pre-initialized `Pipeline` as input.
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TransformersEngine
+
+>>> model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
+>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+>>> model = AutoModelForCausalLM.from_pretrained(model_name)
+
+>>> pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+>>> engine = TransformersEngine(pipe)
+>>> engine([{"role": "user", "content": "Ok!"}], stop_sequences=["great"])
+
+"What a "
+```
+
+[[autodoc]] TransformersEngine
+
+### HfApiEngine
+
+The `HfApiEngine` is an engine that wraps an [HF Inference API](https://huggingface.co/docs/api-inference/index) client for the execution of the LLM.
+
+```python
+>>> from transformers import HfApiEngine
+
+>>> messages = [
+... {"role": "user", "content": "Hello, how are you?"},
+... {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
+... {"role": "user", "content": "No need to help, take it easy."},
+... ]
+
+>>> HfApiEngine()(messages, stop_sequences=["conversation"])
+
+"That's very kind of you to say! It's always nice to have a relaxed "
+```
+
+[[autodoc]] HfApiEngine
+
+
+## Agent Types
+
+Agents can handle any type of object in-between tools; tools, being completely multimodal, can accept and return
+text, image, audio, video, among other types. In order to increase compatibility between tools, as well as to
+correctly render these returns in ipython (jupyter, colab, ipython notebooks, ...), we implement wrapper classes
+around these types.
+
+The wrapped objects should continue behaving as initially; a text object should still behave as a string, an image
+object should still behave as a `PIL.Image`.
+
+These types have three specific purposes:
+
+- Calling `to_raw` on the type should return the underlying object
+- Calling `to_string` on the type should return the object as a string: that can be the string in case of an `AgentText`
+ but will be the path of the serialized version of the object in other instances
+- Displaying it in an ipython kernel should display the object correctly
+
+### AgentText
+
+[[autodoc]] transformers.agents.agent_types.AgentText
+
+### AgentImage
+
+[[autodoc]] transformers.agents.agent_types.AgentImage
+
+### AgentAudio
+
+[[autodoc]] transformers.agents.agent_types.AgentAudio
diff --git a/docs/source/main_classes/backbones.md b/docs/source/main_classes/backbones.md
new file mode 100644
index 0000000..5f1fc1d
--- /dev/null
+++ b/docs/source/main_classes/backbones.md
@@ -0,0 +1,60 @@
+
+
+# Backbone
+
+A backbone is a model used for feature extraction for higher level computer vision tasks such as object detection and image classification. Transformers provides an [`AutoBackbone`] class for initializing a Transformers backbone from pretrained model weights, and two utility classes:
+
+* [`~utils.BackboneMixin`] enables initializing a backbone from Transformers or [timm](https://hf.co/docs/timm/index) and includes functions for returning the output features and indices.
+* [`~utils.BackboneConfigMixin`] sets the output features and indices of the backbone configuration.
+
+[timm](https://hf.co/docs/timm/index) models are loaded with the [`TimmBackbone`] and [`TimmBackboneConfig`] classes.
+
+Backbones are supported for the following models:
+
+* [BEiT](../model_doc/beit)
+* [BiT](../model_doc/bit)
+* [ConvNext](../model_doc/convnext)
+* [ConvNextV2](../model_doc/convnextv2)
+* [DiNAT](../model_doc/dinat)
+* [DINOV2](../model_doc/dinov2)
+* [FocalNet](../model_doc/focalnet)
+* [MaskFormer](../model_doc/maskformer)
+* [NAT](../model_doc/nat)
+* [ResNet](../model_doc/resnet)
+* [Swin Transformer](../model_doc/swin)
+* [Swin Transformer v2](../model_doc/swinv2)
+* [ViTDet](../model_doc/vitdet)
+
+## AutoBackbone
+
+[[autodoc]] AutoBackbone
+
+## BackboneMixin
+
+[[autodoc]] utils.BackboneMixin
+
+## BackboneConfigMixin
+
+[[autodoc]] utils.BackboneConfigMixin
+
+## TimmBackbone
+
+[[autodoc]] models.timm_backbone.TimmBackbone
+
+## TimmBackboneConfig
+
+[[autodoc]] models.timm_backbone.TimmBackboneConfig
diff --git a/docs/source/main_classes/callback.md b/docs/source/main_classes/callback.md
new file mode 100644
index 0000000..ee91737
--- /dev/null
+++ b/docs/source/main_classes/callback.md
@@ -0,0 +1,133 @@
+
+
+# Callbacks
+
+Callbacks are objects that can customize the behavior of the training loop in the PyTorch
+[`Trainer`] (this feature is not yet implemented in TensorFlow) that can inspect the training loop
+state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early
+stopping).
+
+Callbacks are "read only" pieces of code, apart from the [`TrainerControl`] object they return, they
+cannot change anything in the training loop. For customizations that require changes in the training loop, you should
+subclass [`Trainer`] and override the methods you need (see [trainer](trainer) for examples).
+
+By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] will use the following callbacks.
+
+- [`DefaultFlowCallback`] which handles the default behavior for logging, saving and evaluation.
+- [`PrinterCallback`] or [`ProgressCallback`] to display progress and print the
+ logs (the first one is used if you deactivate tqdm through the [`TrainingArguments`], otherwise
+ it's the second one).
+- [`~integrations.TensorBoardCallback`] if tensorboard is accessible (either through PyTorch >= 1.4
+ or tensorboardX).
+- [`~integrations.WandbCallback`] if [wandb](https://www.wandb.com/) is installed.
+- [`~integrations.CometCallback`] if [comet_ml](https://www.comet.com/site/) is installed.
+- [`~integrations.MLflowCallback`] if [mlflow](https://www.mlflow.org/) is installed.
+- [`~integrations.NeptuneCallback`] if [neptune](https://neptune.ai/) is installed.
+- [`~integrations.AzureMLCallback`] if [azureml-sdk](https://pypi.org/project/azureml-sdk/) is
+ installed.
+- [`~integrations.CodeCarbonCallback`] if [codecarbon](https://pypi.org/project/codecarbon/) is
+ installed.
+- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
+- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
+- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
+- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
+
+If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
+
+The main class that implements callbacks is [`TrainerCallback`]. It gets the
+[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
+Trainer's internal state via [`TrainerState`], and can take some actions on the training loop via
+[`TrainerControl`].
+
+
+## Available Callbacks
+
+Here is the list of the available [`TrainerCallback`] in the library:
+
+[[autodoc]] integrations.CometCallback
+ - setup
+
+[[autodoc]] DefaultFlowCallback
+
+[[autodoc]] PrinterCallback
+
+[[autodoc]] ProgressCallback
+
+[[autodoc]] EarlyStoppingCallback
+
+[[autodoc]] integrations.TensorBoardCallback
+
+[[autodoc]] integrations.WandbCallback
+ - setup
+
+[[autodoc]] integrations.MLflowCallback
+ - setup
+
+[[autodoc]] integrations.AzureMLCallback
+
+[[autodoc]] integrations.CodeCarbonCallback
+
+[[autodoc]] integrations.NeptuneCallback
+
+[[autodoc]] integrations.ClearMLCallback
+
+[[autodoc]] integrations.DagsHubCallback
+
+[[autodoc]] integrations.FlyteCallback
+
+[[autodoc]] integrations.DVCLiveCallback
+ - setup
+
+## TrainerCallback
+
+[[autodoc]] TrainerCallback
+
+Here is an example of how to register a custom callback with the PyTorch [`Trainer`]:
+
+```python
+class MyCallback(TrainerCallback):
+ "A callback that prints a message at the beginning of training"
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ print("Starting training")
+
+
+trainer = Trainer(
+ model,
+ args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ callbacks=[MyCallback], # We can either pass the callback class this way or an instance of it (MyCallback())
+)
+```
+
+Another way to register a callback is to call `trainer.add_callback()` as follows:
+
+```python
+trainer = Trainer(...)
+trainer.add_callback(MyCallback)
+# Alternatively, we can pass an instance of the callback class
+trainer.add_callback(MyCallback())
+```
+
+## TrainerState
+
+[[autodoc]] TrainerState
+
+## TrainerControl
+
+[[autodoc]] TrainerControl
diff --git a/docs/source/main_classes/configuration.md b/docs/source/main_classes/configuration.md
new file mode 100644
index 0000000..0cfef06
--- /dev/null
+++ b/docs/source/main_classes/configuration.md
@@ -0,0 +1,32 @@
+
+
+# Configuration
+
+The base class [`PretrainedConfig`] implements the common methods for loading/saving a configuration
+either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded
+from HuggingFace's AWS S3 repository).
+
+Each derived config class implements model specific attributes. Common attributes present in all config classes are:
+`hidden_size`, `num_attention_heads`, and `num_hidden_layers`. Text models further implement:
+`vocab_size`.
+
+
+## PretrainedConfig
+
+[[autodoc]] PretrainedConfig
+ - push_to_hub
+ - all
diff --git a/docs/source/main_classes/data_collator.md b/docs/source/main_classes/data_collator.md
new file mode 100644
index 0000000..e704bb7
--- /dev/null
+++ b/docs/source/main_classes/data_collator.md
@@ -0,0 +1,73 @@
+
+
+# Data Collator
+
+Data collators are objects that will form a batch by using a list of dataset elements as input. These elements are of
+the same type as the elements of `train_dataset` or `eval_dataset`.
+
+To be able to build batches, data collators may apply some processing (like padding). Some of them (like
+[`DataCollatorForLanguageModeling`]) also apply some random data augmentation (like random masking)
+on the formed batch.
+
+Examples of use can be found in the [example scripts](../examples) or [example notebooks](../notebooks).
+
+
+## Default data collator
+
+[[autodoc]] data.data_collator.default_data_collator
+
+## DefaultDataCollator
+
+[[autodoc]] data.data_collator.DefaultDataCollator
+
+## DataCollatorWithPadding
+
+[[autodoc]] data.data_collator.DataCollatorWithPadding
+
+## DataCollatorForTokenClassification
+
+[[autodoc]] data.data_collator.DataCollatorForTokenClassification
+
+## DataCollatorForSeq2Seq
+
+[[autodoc]] data.data_collator.DataCollatorForSeq2Seq
+
+## DataCollatorForLanguageModeling
+
+[[autodoc]] data.data_collator.DataCollatorForLanguageModeling
+ - numpy_mask_tokens
+ - tf_mask_tokens
+ - torch_mask_tokens
+
+## DataCollatorForWholeWordMask
+
+[[autodoc]] data.data_collator.DataCollatorForWholeWordMask
+ - numpy_mask_tokens
+ - tf_mask_tokens
+ - torch_mask_tokens
+
+## DataCollatorForPermutationLanguageModeling
+
+[[autodoc]] data.data_collator.DataCollatorForPermutationLanguageModeling
+ - numpy_mask_tokens
+ - tf_mask_tokens
+ - torch_mask_tokens
+
+## DataCollatorWithFlattening
+
+[[autodoc]] data.data_collator.DataCollatorWithFlattening
+
diff --git a/docs/source/main_classes/deepspeed.md b/docs/source/main_classes/deepspeed.md
new file mode 100644
index 0000000..5863f66
--- /dev/null
+++ b/docs/source/main_classes/deepspeed.md
@@ -0,0 +1,32 @@
+
+
+# DeepSpeed
+
+[DeepSpeed](https://github.com/microsoft/DeepSpeed), powered by Zero Redundancy Optimizer (ZeRO), is an optimization library for training and fitting very large models onto a GPU. It is available in several ZeRO stages, where each stage progressively saves more GPU memory by partitioning the optimizer state, gradients, parameters, and enabling offloading to a CPU or NVMe. DeepSpeed is integrated with the [`Trainer`] class and most of the setup is automatically taken care of for you.
+
+However, if you want to use DeepSpeed without the [`Trainer`], Transformers provides a [`HfDeepSpeedConfig`] class.
+
+
+
+Learn more about using DeepSpeed with [`Trainer`] in the [DeepSpeed](../deepspeed) guide.
+
+
+
+## HfDeepSpeedConfig
+
+[[autodoc]] integrations.HfDeepSpeedConfig
+ - all
diff --git a/docs/source/main_classes/executorch.md b/docs/source/main_classes/executorch.md
new file mode 100644
index 0000000..3178085
--- /dev/null
+++ b/docs/source/main_classes/executorch.md
@@ -0,0 +1,33 @@
+
+
+
+# ExecuTorch
+
+[`ExecuTorch`](https://github.com/pytorch/executorch) is an end-to-end solution for enabling on-device inference capabilities across mobile and edge devices including wearables, embedded devices and microcontrollers. It is part of the PyTorch ecosystem and supports the deployment of PyTorch models with a focus on portability, productivity, and performance.
+
+ExecuTorch introduces well defined entry points to perform model, device, and/or use-case specific optimizations such as backend delegation, user-defined compiler transformations, memory planning, and more. The first step in preparing a PyTorch model for execution on an edge device using ExecuTorch is to export the model. This is achieved through the use of a PyTorch API called [`torch.export`](https://pytorch.org/docs/stable/export.html).
+
+
+## ExecuTorch Integration
+
+An integration point is being developed to ensure that ๐ค Transformers can be exported using `torch.export`. The goal of this integration is not only to enable export but also to ensure that the exported artifact can be further lowered and optimized to run efficiently in `ExecuTorch`, particularly for mobile and edge use cases.
+
+[[autodoc]] TorchExportableModuleWithStaticCache
+ - forward
+
+[[autodoc]] convert_and_export_with_cache
diff --git a/docs/source/main_classes/feature_extractor.md b/docs/source/main_classes/feature_extractor.md
new file mode 100644
index 0000000..d79c531
--- /dev/null
+++ b/docs/source/main_classes/feature_extractor.md
@@ -0,0 +1,39 @@
+
+
+# Feature Extractor
+
+A feature extractor is in charge of preparing input features for audio or vision models. This includes feature extraction from sequences, e.g., pre-processing audio files to generate Log-Mel Spectrogram features, feature extraction from images, e.g., cropping image files, but also padding, normalization, and conversion to NumPy, PyTorch, and TensorFlow tensors.
+
+
+## FeatureExtractionMixin
+
+[[autodoc]] feature_extraction_utils.FeatureExtractionMixin
+ - from_pretrained
+ - save_pretrained
+
+## SequenceFeatureExtractor
+
+[[autodoc]] SequenceFeatureExtractor
+ - pad
+
+## BatchFeature
+
+[[autodoc]] BatchFeature
+
+## ImageFeatureExtractionMixin
+
+[[autodoc]] image_utils.ImageFeatureExtractionMixin
diff --git a/docs/source/main_classes/image_processor.md b/docs/source/main_classes/image_processor.md
new file mode 100644
index 0000000..320916f
--- /dev/null
+++ b/docs/source/main_classes/image_processor.md
@@ -0,0 +1,82 @@
+
+
+# Image Processor
+
+An image processor is in charge of preparing input features for vision models and post processing their outputs. This includes transformations such as resizing, normalization, and conversion to PyTorch, TensorFlow, Flax and Numpy tensors. It may also include model specific post-processing such as converting logits to segmentation masks.
+
+Fast image processors are available for a few models and more will be added in the future. They are based on the [torchvision](https://pytorch.org/vision/stable/index.html) library and provide a significant speed-up, especially when processing on GPU.
+They have the same API as the base image processors and can be used as drop-in replacements.
+To use a fast image processor, you need to install the `torchvision` library, and set the `use_fast` argument to `True` when instantiating the image processor:
+
+```python
+from transformers import AutoImageProcessor
+
+processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
+```
+
+When using a fast image processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise.
+
+```python
+from torchvision.io import read_image
+from transformers import DetrImageProcessorFast
+
+images = read_image("image.jpg")
+processor = DetrImageProcessorFast.from_pretrained("facebook/detr-resnet-50")
+images_processed = processor(images, return_tensors="pt", device="cuda")
+```
+
+Here are some speed comparisons between the base and fast image processors for the `DETR` and `RT-DETR` models, and how they impact overall inference time:
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+These benchmarks were run on an [AWS EC2 g5.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g5/), utilizing an NVIDIA A10G Tensor Core GPU.
+
+
+## ImageProcessingMixin
+
+[[autodoc]] image_processing_utils.ImageProcessingMixin
+ - from_pretrained
+ - save_pretrained
+
+## BatchFeature
+
+[[autodoc]] BatchFeature
+
+## BaseImageProcessor
+
+[[autodoc]] image_processing_utils.BaseImageProcessor
+
+
+## BaseImageProcessorFast
+
+[[autodoc]] image_processing_utils_fast.BaseImageProcessorFast
diff --git a/docs/source/main_classes/keras_callbacks.md b/docs/source/main_classes/keras_callbacks.md
new file mode 100644
index 0000000..c993230
--- /dev/null
+++ b/docs/source/main_classes/keras_callbacks.md
@@ -0,0 +1,28 @@
+
+
+# Keras callbacks
+
+When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
+tasks:
+
+## KerasMetricCallback
+
+[[autodoc]] KerasMetricCallback
+
+## PushToHubCallback
+
+[[autodoc]] PushToHubCallback
diff --git a/docs/source/main_classes/logging.md b/docs/source/main_classes/logging.md
new file mode 100644
index 0000000..5cbdf9a
--- /dev/null
+++ b/docs/source/main_classes/logging.md
@@ -0,0 +1,119 @@
+
+
+# Logging
+
+๐ค Transformers has a centralized logging system, so that you can setup the verbosity of the library easily.
+
+Currently the default verbosity of the library is `WARNING`.
+
+To change the level of verbosity, just use one of the direct setters. For instance, here is how to change the verbosity
+to the INFO level.
+
+```python
+import transformers
+
+transformers.logging.set_verbosity_info()
+```
+
+You can also use the environment variable `TRANSFORMERS_VERBOSITY` to override the default verbosity. You can set it
+to one of the following: `debug`, `info`, `warning`, `error`, `critical`, `fatal`. For example:
+
+```bash
+TRANSFORMERS_VERBOSITY=error ./myprogram.py
+```
+
+Additionally, some `warnings` can be disabled by setting the environment variable
+`TRANSFORMERS_NO_ADVISORY_WARNINGS` to a true value, like *1*. This will disable any warning that is logged using
+[`logger.warning_advice`]. For example:
+
+```bash
+TRANSFORMERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py
+```
+
+Here is an example of how to use the same logger as the library in your own module or script:
+
+```python
+from transformers.utils import logging
+
+logging.set_verbosity_info()
+logger = logging.get_logger("transformers")
+logger.info("INFO")
+logger.warning("WARN")
+```
+
+
+All the methods of this logging module are documented below, the main ones are
+[`logging.get_verbosity`] to get the current level of verbosity in the logger and
+[`logging.set_verbosity`] to set the verbosity to the level of your choice. In order (from the least
+verbose to the most verbose), those levels (with their corresponding int values in parenthesis) are:
+
+- `transformers.logging.CRITICAL` or `transformers.logging.FATAL` (int value, 50): only report the most
+ critical errors.
+- `transformers.logging.ERROR` (int value, 40): only report errors.
+- `transformers.logging.WARNING` or `transformers.logging.WARN` (int value, 30): only reports error and
+ warnings. This is the default level used by the library.
+- `transformers.logging.INFO` (int value, 20): reports error, warnings and basic information.
+- `transformers.logging.DEBUG` (int value, 10): report all information.
+
+By default, `tqdm` progress bars will be displayed during model download. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior.
+
+## `logging` vs `warnings`
+
+Python has two logging systems that are often used in conjunction: `logging`, which is explained above, and `warnings`,
+which allows further classification of warnings in specific buckets, e.g., `FutureWarning` for a feature or path
+that has already been deprecated and `DeprecationWarning` to indicate an upcoming deprecation.
+
+We use both in the `transformers` library. We leverage and adapt `logging`'s `captureWarnings` method to allow
+management of these warning messages by the verbosity setters above.
+
+What does that mean for developers of the library? We should respect the following heuristics:
+- `warnings` should be favored for developers of the library and libraries dependent on `transformers`
+- `logging` should be used for end-users of the library using it in every-day projects
+
+See reference of the `captureWarnings` method below.
+
+[[autodoc]] logging.captureWarnings
+
+## Base setters
+
+[[autodoc]] logging.set_verbosity_error
+
+[[autodoc]] logging.set_verbosity_warning
+
+[[autodoc]] logging.set_verbosity_info
+
+[[autodoc]] logging.set_verbosity_debug
+
+## Other functions
+
+[[autodoc]] logging.get_verbosity
+
+[[autodoc]] logging.set_verbosity
+
+[[autodoc]] logging.get_logger
+
+[[autodoc]] logging.enable_default_handler
+
+[[autodoc]] logging.disable_default_handler
+
+[[autodoc]] logging.enable_explicit_format
+
+[[autodoc]] logging.reset_format
+
+[[autodoc]] logging.enable_progress_bar
+
+[[autodoc]] logging.disable_progress_bar
diff --git a/docs/source/main_classes/model.md b/docs/source/main_classes/model.md
new file mode 100644
index 0000000..15345a7
--- /dev/null
+++ b/docs/source/main_classes/model.md
@@ -0,0 +1,73 @@
+
+
+# Models
+
+The base classes [`PreTrainedModel`], [`TFPreTrainedModel`], and
+[`FlaxPreTrainedModel`] implement the common methods for loading/saving a model either from a local
+file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS
+S3 repository).
+
+[`PreTrainedModel`] and [`TFPreTrainedModel`] also implement a few methods which
+are common among all the models to:
+
+- resize the input token embeddings when new tokens are added to the vocabulary
+- prune the attention heads of the model.
+
+The other methods that are common to each model are defined in [`~modeling_utils.ModuleUtilsMixin`]
+(for the PyTorch models) and [`~modeling_tf_utils.TFModuleUtilsMixin`] (for the TensorFlow models) or
+for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
+[`~generation.TFGenerationMixin`] (for the TensorFlow models) and
+[`~generation.FlaxGenerationMixin`] (for the Flax/JAX models).
+
+
+## PreTrainedModel
+
+[[autodoc]] PreTrainedModel
+ - push_to_hub
+ - all
+
+Custom models should also include a `_supports_assign_param_buffer`, which determines if superfast init can apply
+on the particular model. Signs that your model needs this are if `test_save_and_load_from_pretrained` fails. If so,
+set this to `False`.
+
+## ModuleUtilsMixin
+
+[[autodoc]] modeling_utils.ModuleUtilsMixin
+
+## TFPreTrainedModel
+
+[[autodoc]] TFPreTrainedModel
+ - push_to_hub
+ - all
+
+## TFModelUtilsMixin
+
+[[autodoc]] modeling_tf_utils.TFModelUtilsMixin
+
+## FlaxPreTrainedModel
+
+[[autodoc]] FlaxPreTrainedModel
+ - push_to_hub
+ - all
+
+## Pushing to the Hub
+
+[[autodoc]] utils.PushToHubMixin
+
+## Sharded checkpoints
+
+[[autodoc]] modeling_utils.load_sharded_checkpoint
diff --git a/docs/source/main_classes/onnx.md b/docs/source/main_classes/onnx.md
new file mode 100644
index 0000000..81d31c9
--- /dev/null
+++ b/docs/source/main_classes/onnx.md
@@ -0,0 +1,54 @@
+
+
+# Exporting ๐ค Transformers models to ONNX
+
+๐ค Transformers provides a `transformers.onnx` package that enables you to
+convert model checkpoints to an ONNX graph by leveraging configuration objects.
+
+See the [guide](../serialization) on exporting ๐ค Transformers models for more
+details.
+
+## ONNX Configurations
+
+We provide three abstract classes that you should inherit from, depending on the
+type of model architecture you wish to export:
+
+* Encoder-based models inherit from [`~onnx.config.OnnxConfig`]
+* Decoder-based models inherit from [`~onnx.config.OnnxConfigWithPast`]
+* Encoder-decoder models inherit from [`~onnx.config.OnnxSeq2SeqConfigWithPast`]
+
+### OnnxConfig
+
+[[autodoc]] onnx.config.OnnxConfig
+
+### OnnxConfigWithPast
+
+[[autodoc]] onnx.config.OnnxConfigWithPast
+
+### OnnxSeq2SeqConfigWithPast
+
+[[autodoc]] onnx.config.OnnxSeq2SeqConfigWithPast
+
+## ONNX Features
+
+Each ONNX configuration is associated with a set of _features_ that enable you
+to export models for different types of topologies or tasks.
+
+### FeaturesManager
+
+[[autodoc]] onnx.features.FeaturesManager
+
diff --git a/docs/source/main_classes/optimizer_schedules.md b/docs/source/main_classes/optimizer_schedules.md
new file mode 100644
index 0000000..9815b43
--- /dev/null
+++ b/docs/source/main_classes/optimizer_schedules.md
@@ -0,0 +1,79 @@
+
+
+# Optimization
+
+The `.optimization` module provides:
+
+- an optimizer with weight decay fixed that can be used to fine-tuned models, and
+- several schedules in the form of schedule objects that inherit from `_LRSchedule`:
+- a gradient accumulation class to accumulate the gradients of multiple batches
+
+## AdamW (PyTorch)
+
+[[autodoc]] AdamW
+
+## AdaFactor (PyTorch)
+
+[[autodoc]] Adafactor
+
+## AdamWeightDecay (TensorFlow)
+
+[[autodoc]] AdamWeightDecay
+
+[[autodoc]] create_optimizer
+
+## Schedules
+
+### Learning Rate Schedules (PyTorch)
+
+[[autodoc]] SchedulerType
+
+[[autodoc]] get_scheduler
+
+[[autodoc]] get_constant_schedule
+
+[[autodoc]] get_constant_schedule_with_warmup
+
+
+
+[[autodoc]] get_cosine_schedule_with_warmup
+
+
+
+[[autodoc]] get_cosine_with_hard_restarts_schedule_with_warmup
+
+
+
+[[autodoc]] get_linear_schedule_with_warmup
+
+
+
+[[autodoc]] get_polynomial_decay_schedule_with_warmup
+
+[[autodoc]] get_inverse_sqrt_schedule
+
+[[autodoc]] get_wsd_schedule
+
+### Warmup (TensorFlow)
+
+[[autodoc]] WarmUp
+
+## Gradient Strategies
+
+### GradientAccumulator (TensorFlow)
+
+[[autodoc]] GradientAccumulator
diff --git a/docs/source/main_classes/output.md b/docs/source/main_classes/output.md
new file mode 100644
index 0000000..300213d
--- /dev/null
+++ b/docs/source/main_classes/output.md
@@ -0,0 +1,321 @@
+
+
+# Model outputs
+
+All models have outputs that are instances of subclasses of [`~utils.ModelOutput`]. Those are
+data structures containing all the information returned by the model, but that can also be used as tuples or
+dictionaries.
+
+Let's see how this looks in an example:
+
+```python
+from transformers import BertTokenizer, BertForSequenceClassification
+import torch
+
+tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-uncased")
+
+inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
+outputs = model(**inputs, labels=labels)
+```
+
+The `outputs` object is a [`~modeling_outputs.SequenceClassifierOutput`], as we can see in the
+documentation of that class below, it means it has an optional `loss`, a `logits`, an optional `hidden_states` and
+an optional `attentions` attribute. Here we have the `loss` since we passed along `labels`, but we don't have
+`hidden_states` and `attentions` because we didn't pass `output_hidden_states=True` or
+`output_attentions=True`.
+
+
+
+When passing `output_hidden_states=True` you may expect the `outputs.hidden_states[-1]` to match `outputs.last_hidden_state` exactly.
+However, this is not always the case. Some models apply normalization or subsequent process to the last hidden state when it's returned.
+
+
+
+
+You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
+will get `None`. Here for instance `outputs.loss` is the loss computed by the model, and `outputs.attentions` is
+`None`.
+
+When considering our `outputs` object as tuple, it only considers the attributes that don't have `None` values.
+Here for instance, it has two elements, `loss` then `logits`, so
+
+```python
+outputs[:2]
+```
+
+will return the tuple `(outputs.loss, outputs.logits)` for instance.
+
+When considering our `outputs` object as dictionary, it only considers the attributes that don't have `None`
+values. Here for instance, it has two keys that are `loss` and `logits`.
+
+We document here the generic model outputs that are used by more than one model type. Specific output types are
+documented on their corresponding model page.
+
+## ModelOutput
+
+[[autodoc]] utils.ModelOutput
+ - to_tuple
+
+## BaseModelOutput
+
+[[autodoc]] modeling_outputs.BaseModelOutput
+
+## BaseModelOutputWithPooling
+
+[[autodoc]] modeling_outputs.BaseModelOutputWithPooling
+
+## BaseModelOutputWithCrossAttentions
+
+[[autodoc]] modeling_outputs.BaseModelOutputWithCrossAttentions
+
+## BaseModelOutputWithPoolingAndCrossAttentions
+
+[[autodoc]] modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
+
+## BaseModelOutputWithPast
+
+[[autodoc]] modeling_outputs.BaseModelOutputWithPast
+
+## BaseModelOutputWithPastAndCrossAttentions
+
+[[autodoc]] modeling_outputs.BaseModelOutputWithPastAndCrossAttentions
+
+## Seq2SeqModelOutput
+
+[[autodoc]] modeling_outputs.Seq2SeqModelOutput
+
+## CausalLMOutput
+
+[[autodoc]] modeling_outputs.CausalLMOutput
+
+## CausalLMOutputWithCrossAttentions
+
+[[autodoc]] modeling_outputs.CausalLMOutputWithCrossAttentions
+
+## CausalLMOutputWithPast
+
+[[autodoc]] modeling_outputs.CausalLMOutputWithPast
+
+## MaskedLMOutput
+
+[[autodoc]] modeling_outputs.MaskedLMOutput
+
+## Seq2SeqLMOutput
+
+[[autodoc]] modeling_outputs.Seq2SeqLMOutput
+
+## NextSentencePredictorOutput
+
+[[autodoc]] modeling_outputs.NextSentencePredictorOutput
+
+## SequenceClassifierOutput
+
+[[autodoc]] modeling_outputs.SequenceClassifierOutput
+
+## Seq2SeqSequenceClassifierOutput
+
+[[autodoc]] modeling_outputs.Seq2SeqSequenceClassifierOutput
+
+## MultipleChoiceModelOutput
+
+[[autodoc]] modeling_outputs.MultipleChoiceModelOutput
+
+## TokenClassifierOutput
+
+[[autodoc]] modeling_outputs.TokenClassifierOutput
+
+## QuestionAnsweringModelOutput
+
+[[autodoc]] modeling_outputs.QuestionAnsweringModelOutput
+
+## Seq2SeqQuestionAnsweringModelOutput
+
+[[autodoc]] modeling_outputs.Seq2SeqQuestionAnsweringModelOutput
+
+## Seq2SeqSpectrogramOutput
+
+[[autodoc]] modeling_outputs.Seq2SeqSpectrogramOutput
+
+## SemanticSegmenterOutput
+
+[[autodoc]] modeling_outputs.SemanticSegmenterOutput
+
+## ImageClassifierOutput
+
+[[autodoc]] modeling_outputs.ImageClassifierOutput
+
+## ImageClassifierOutputWithNoAttention
+
+[[autodoc]] modeling_outputs.ImageClassifierOutputWithNoAttention
+
+## DepthEstimatorOutput
+
+[[autodoc]] modeling_outputs.DepthEstimatorOutput
+
+## Wav2Vec2BaseModelOutput
+
+[[autodoc]] modeling_outputs.Wav2Vec2BaseModelOutput
+
+## XVectorOutput
+
+[[autodoc]] modeling_outputs.XVectorOutput
+
+## Seq2SeqTSModelOutput
+
+[[autodoc]] modeling_outputs.Seq2SeqTSModelOutput
+
+## Seq2SeqTSPredictionOutput
+
+[[autodoc]] modeling_outputs.Seq2SeqTSPredictionOutput
+
+## SampleTSPredictionOutput
+
+[[autodoc]] modeling_outputs.SampleTSPredictionOutput
+
+## TFBaseModelOutput
+
+[[autodoc]] modeling_tf_outputs.TFBaseModelOutput
+
+## TFBaseModelOutputWithPooling
+
+[[autodoc]] modeling_tf_outputs.TFBaseModelOutputWithPooling
+
+## TFBaseModelOutputWithPoolingAndCrossAttentions
+
+[[autodoc]] modeling_tf_outputs.TFBaseModelOutputWithPoolingAndCrossAttentions
+
+## TFBaseModelOutputWithPast
+
+[[autodoc]] modeling_tf_outputs.TFBaseModelOutputWithPast
+
+## TFBaseModelOutputWithPastAndCrossAttentions
+
+[[autodoc]] modeling_tf_outputs.TFBaseModelOutputWithPastAndCrossAttentions
+
+## TFSeq2SeqModelOutput
+
+[[autodoc]] modeling_tf_outputs.TFSeq2SeqModelOutput
+
+## TFCausalLMOutput
+
+[[autodoc]] modeling_tf_outputs.TFCausalLMOutput
+
+## TFCausalLMOutputWithCrossAttentions
+
+[[autodoc]] modeling_tf_outputs.TFCausalLMOutputWithCrossAttentions
+
+## TFCausalLMOutputWithPast
+
+[[autodoc]] modeling_tf_outputs.TFCausalLMOutputWithPast
+
+## TFMaskedLMOutput
+
+[[autodoc]] modeling_tf_outputs.TFMaskedLMOutput
+
+## TFSeq2SeqLMOutput
+
+[[autodoc]] modeling_tf_outputs.TFSeq2SeqLMOutput
+
+## TFNextSentencePredictorOutput
+
+[[autodoc]] modeling_tf_outputs.TFNextSentencePredictorOutput
+
+## TFSequenceClassifierOutput
+
+[[autodoc]] modeling_tf_outputs.TFSequenceClassifierOutput
+
+## TFSeq2SeqSequenceClassifierOutput
+
+[[autodoc]] modeling_tf_outputs.TFSeq2SeqSequenceClassifierOutput
+
+## TFMultipleChoiceModelOutput
+
+[[autodoc]] modeling_tf_outputs.TFMultipleChoiceModelOutput
+
+## TFTokenClassifierOutput
+
+[[autodoc]] modeling_tf_outputs.TFTokenClassifierOutput
+
+## TFQuestionAnsweringModelOutput
+
+[[autodoc]] modeling_tf_outputs.TFQuestionAnsweringModelOutput
+
+## TFSeq2SeqQuestionAnsweringModelOutput
+
+[[autodoc]] modeling_tf_outputs.TFSeq2SeqQuestionAnsweringModelOutput
+
+## FlaxBaseModelOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxBaseModelOutput
+
+## FlaxBaseModelOutputWithPast
+
+[[autodoc]] modeling_flax_outputs.FlaxBaseModelOutputWithPast
+
+## FlaxBaseModelOutputWithPooling
+
+[[autodoc]] modeling_flax_outputs.FlaxBaseModelOutputWithPooling
+
+## FlaxBaseModelOutputWithPastAndCrossAttentions
+
+[[autodoc]] modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions
+
+## FlaxSeq2SeqModelOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxSeq2SeqModelOutput
+
+## FlaxCausalLMOutputWithCrossAttentions
+
+[[autodoc]] modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions
+
+## FlaxMaskedLMOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxMaskedLMOutput
+
+## FlaxSeq2SeqLMOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxSeq2SeqLMOutput
+
+## FlaxNextSentencePredictorOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxNextSentencePredictorOutput
+
+## FlaxSequenceClassifierOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxSequenceClassifierOutput
+
+## FlaxSeq2SeqSequenceClassifierOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput
+
+## FlaxMultipleChoiceModelOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxMultipleChoiceModelOutput
+
+## FlaxTokenClassifierOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxTokenClassifierOutput
+
+## FlaxQuestionAnsweringModelOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxQuestionAnsweringModelOutput
+
+## FlaxSeq2SeqQuestionAnsweringModelOutput
+
+[[autodoc]] modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput
diff --git a/docs/source/main_classes/pipelines.md b/docs/source/main_classes/pipelines.md
new file mode 100644
index 0000000..59e474f
--- /dev/null
+++ b/docs/source/main_classes/pipelines.md
@@ -0,0 +1,501 @@
+
+
+# Pipelines
+
+The pipelines are a great and easy way to use models for inference. These pipelines are objects that abstract most of
+the complex code from the library, offering a simple API dedicated to several tasks, including Named Entity
+Recognition, Masked Language Modeling, Sentiment Analysis, Feature Extraction and Question Answering. See the
+[task summary](../task_summary) for examples of use.
+
+There are two categories of pipeline abstractions to be aware about:
+
+- The [`pipeline`] which is the most powerful object encapsulating all other pipelines.
+- Task-specific pipelines are available for [audio](#audio), [computer vision](#computer-vision), [natural language processing](#natural-language-processing), and [multimodal](#multimodal) tasks.
+
+## The pipeline abstraction
+
+The *pipeline* abstraction is a wrapper around all the other available pipelines. It is instantiated as any other
+pipeline but can provide additional quality of life.
+
+Simple call on one item:
+
+```python
+>>> pipe = pipeline("text-classification")
+>>> pipe("This restaurant is awesome")
+[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
+```
+
+If you want to use a specific model from the [hub](https://huggingface.co) you can ignore the task if the model on
+the hub already defines it:
+
+```python
+>>> pipe = pipeline(model="FacebookAI/roberta-large-mnli")
+>>> pipe("This restaurant is awesome")
+[{'label': 'NEUTRAL', 'score': 0.7313136458396912}]
+```
+
+To call a pipeline on many items, you can call it with a *list*.
+
+```python
+>>> pipe = pipeline("text-classification")
+>>> pipe(["This restaurant is awesome", "This restaurant is awful"])
+[{'label': 'POSITIVE', 'score': 0.9998743534088135},
+ {'label': 'NEGATIVE', 'score': 0.9996669292449951}]
+```
+
+To iterate over full datasets it is recommended to use a `dataset` directly. This means you don't need to allocate
+the whole dataset at once, nor do you need to do batching yourself. This should work just as fast as custom loops on
+GPU. If it doesn't don't hesitate to create an issue.
+
+```python
+import datasets
+from transformers import pipeline
+from transformers.pipelines.pt_utils import KeyDataset
+from tqdm.auto import tqdm
+
+pipe = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=0)
+dataset = datasets.load_dataset("superb", name="asr", split="test")
+
+# KeyDataset (only *pt*) will simply return the item in the dict returned by the dataset item
+# as we're not interested in the *target* part of the dataset. For sentence pair use KeyPairDataset
+for out in tqdm(pipe(KeyDataset(dataset, "file"))):
+ print(out)
+ # {"text": "NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND"}
+ # {"text": ....}
+ # ....
+```
+
+For ease of use, a generator is also possible:
+
+
+```python
+from transformers import pipeline
+
+pipe = pipeline("text-classification")
+
+
+def data():
+ while True:
+ # This could come from a dataset, a database, a queue or HTTP request
+ # in a server
+ # Caveat: because this is iterative, you cannot use `num_workers > 1` variable
+ # to use multiple threads to preprocess data. You can still have 1 thread that
+ # does the preprocessing while the main runs the big inference
+ yield "This is a test"
+
+
+for out in pipe(data()):
+ print(out)
+ # {"text": "NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND"}
+ # {"text": ....}
+ # ....
+```
+
+[[autodoc]] pipeline
+
+## Pipeline batching
+
+All pipelines can use batching. This will work
+whenever the pipeline uses its streaming ability (so when passing lists or `Dataset` or `generator`).
+
+```python
+from transformers import pipeline
+from transformers.pipelines.pt_utils import KeyDataset
+import datasets
+
+dataset = datasets.load_dataset("imdb", name="plain_text", split="unsupervised")
+pipe = pipeline("text-classification", device=0)
+for out in pipe(KeyDataset(dataset, "text"), batch_size=8, truncation="only_first"):
+ print(out)
+ # [{'label': 'POSITIVE', 'score': 0.9998743534088135}]
+ # Exactly the same output as before, but the content are passed
+ # as batches to the model
+```
+
+
+
+However, this is not automatically a win for performance. It can be either a 10x speedup or 5x slowdown depending
+on hardware, data and the actual model being used.
+
+Example where it's mostly a speedup:
+
+
+
+```python
+from transformers import pipeline
+from torch.utils.data import Dataset
+from tqdm.auto import tqdm
+
+pipe = pipeline("text-classification", device=0)
+
+
+class MyDataset(Dataset):
+ def __len__(self):
+ return 5000
+
+ def __getitem__(self, i):
+ return "This is a test"
+
+
+dataset = MyDataset()
+
+for batch_size in [1, 8, 64, 256]:
+ print("-" * 30)
+ print(f"Streaming batch_size={batch_size}")
+ for out in tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
+ pass
+```
+
+```
+# On GTX 970
+------------------------------
+Streaming no batching
+100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 5000/5000 [00:26<00:00, 187.52it/s]
+------------------------------
+Streaming batch_size=8
+100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 5000/5000 [00:04<00:00, 1205.95it/s]
+------------------------------
+Streaming batch_size=64
+100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 5000/5000 [00:02<00:00, 2478.24it/s]
+------------------------------
+Streaming batch_size=256
+100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 5000/5000 [00:01<00:00, 2554.43it/s]
+(diminishing returns, saturated the GPU)
+```
+
+Example where it's most a slowdown:
+
+```python
+class MyDataset(Dataset):
+ def __len__(self):
+ return 5000
+
+ def __getitem__(self, i):
+ if i % 64 == 0:
+ n = 100
+ else:
+ n = 1
+ return "This is a test" * n
+```
+
+This is a occasional very long sentence compared to the other. In that case, the **whole** batch will need to be 400
+tokens long, so the whole batch will be [64, 400] instead of [64, 4], leading to the high slowdown. Even worse, on
+bigger batches, the program simply crashes.
+
+
+```
+------------------------------
+Streaming no batching
+100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 1000/1000 [00:05<00:00, 183.69it/s]
+------------------------------
+Streaming batch_size=8
+100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 1000/1000 [00:03<00:00, 265.74it/s]
+------------------------------
+Streaming batch_size=64
+100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 1000/1000 [00:26<00:00, 37.80it/s]
+------------------------------
+Streaming batch_size=256
+ 0%| | 0/1000 [00:00, ?it/s]
+Traceback (most recent call last):
+ File "/home/nicolas/src/transformers/test.py", line 42, in
+ for out in tqdm(pipe(dataset, batch_size=256), total=len(dataset)):
+....
+ q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
+RuntimeError: CUDA out of memory. Tried to allocate 376.00 MiB (GPU 0; 3.95 GiB total capacity; 1.72 GiB already allocated; 354.88 MiB free; 2.46 GiB reserved in total by PyTorch)
+```
+
+There are no good (general) solutions for this problem, and your mileage may vary depending on your use cases. Rule of
+thumb:
+
+For users, a rule of thumb is:
+
+- **Measure performance on your load, with your hardware. Measure, measure, and keep measuring. Real numbers are the
+ only way to go.**
+- If you are latency constrained (live product doing inference), don't batch.
+- If you are using CPU, don't batch.
+- If you are using throughput (you want to run your model on a bunch of static data), on GPU, then:
+
+ - If you have no clue about the size of the sequence_length ("natural" data), by default don't batch, measure and
+ try tentatively to add it, add OOM checks to recover when it will fail (and it will at some point if you don't
+ control the sequence_length.)
+ - If your sequence_length is super regular, then batching is more likely to be VERY interesting, measure and push
+ it until you get OOMs.
+ - The larger the GPU the more likely batching is going to be more interesting
+- As soon as you enable batching, make sure you can handle OOMs nicely.
+
+## Pipeline chunk batching
+
+`zero-shot-classification` and `question-answering` are slightly specific in the sense, that a single input might yield
+multiple forward pass of a model. Under normal circumstances, this would yield issues with `batch_size` argument.
+
+In order to circumvent this issue, both of these pipelines are a bit specific, they are `ChunkPipeline` instead of
+regular `Pipeline`. In short:
+
+
+```python
+preprocessed = pipe.preprocess(inputs)
+model_outputs = pipe.forward(preprocessed)
+outputs = pipe.postprocess(model_outputs)
+```
+
+Now becomes:
+
+
+```python
+all_model_outputs = []
+for preprocessed in pipe.preprocess(inputs):
+ model_outputs = pipe.forward(preprocessed)
+ all_model_outputs.append(model_outputs)
+outputs = pipe.postprocess(all_model_outputs)
+```
+
+This should be very transparent to your code because the pipelines are used in
+the same way.
+
+This is a simplified view, since the pipeline can handle automatically the batch to ! Meaning you don't have to care
+about how many forward passes you inputs are actually going to trigger, you can optimize the `batch_size`
+independently of the inputs. The caveats from the previous section still apply.
+
+## Pipeline FP16 inference
+Models can be run in FP16 which can be significantly faster on GPU while saving memory. Most models will not suffer noticeable performance loss from this. The larger the model, the less likely that it will.
+
+To enable FP16 inference, you can simply pass `torch_dtype=torch.float16` or `torch_dtype='float16'` to the pipeline constructor. Note that this only works for models with a PyTorch backend. Your inputs will be converted to FP16 internally.
+
+## Pipeline custom code
+
+If you want to override a specific pipeline.
+
+Don't hesitate to create an issue for your task at hand, the goal of the pipeline is to be easy to use and support most
+cases, so `transformers` could maybe support your use case.
+
+
+If you want to try simply you can:
+
+- Subclass your pipeline of choice
+
+```python
+class MyPipeline(TextClassificationPipeline):
+ def postprocess():
+ # Your code goes here
+ scores = scores * 100
+ # And here
+
+
+my_pipeline = MyPipeline(model=model, tokenizer=tokenizer, ...)
+# or if you use *pipeline* function, then:
+my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
+```
+
+That should enable you to do all the custom code you want.
+
+
+## Implementing a pipeline
+
+[Implementing a new pipeline](../add_new_pipeline)
+
+## Audio
+
+Pipelines available for audio tasks include the following.
+
+### AudioClassificationPipeline
+
+[[autodoc]] AudioClassificationPipeline
+ - __call__
+ - all
+
+### AutomaticSpeechRecognitionPipeline
+
+[[autodoc]] AutomaticSpeechRecognitionPipeline
+ - __call__
+ - all
+
+### TextToAudioPipeline
+
+[[autodoc]] TextToAudioPipeline
+ - __call__
+ - all
+
+
+### ZeroShotAudioClassificationPipeline
+
+[[autodoc]] ZeroShotAudioClassificationPipeline
+ - __call__
+ - all
+
+## Computer vision
+
+Pipelines available for computer vision tasks include the following.
+
+### DepthEstimationPipeline
+[[autodoc]] DepthEstimationPipeline
+ - __call__
+ - all
+
+### ImageClassificationPipeline
+
+[[autodoc]] ImageClassificationPipeline
+ - __call__
+ - all
+
+### ImageSegmentationPipeline
+
+[[autodoc]] ImageSegmentationPipeline
+ - __call__
+ - all
+
+### ImageToImagePipeline
+
+[[autodoc]] ImageToImagePipeline
+ - __call__
+ - all
+
+### ObjectDetectionPipeline
+
+[[autodoc]] ObjectDetectionPipeline
+ - __call__
+ - all
+
+### VideoClassificationPipeline
+
+[[autodoc]] VideoClassificationPipeline
+ - __call__
+ - all
+
+### ZeroShotImageClassificationPipeline
+
+[[autodoc]] ZeroShotImageClassificationPipeline
+ - __call__
+ - all
+
+### ZeroShotObjectDetectionPipeline
+
+[[autodoc]] ZeroShotObjectDetectionPipeline
+ - __call__
+ - all
+
+## Natural Language Processing
+
+Pipelines available for natural language processing tasks include the following.
+
+### FillMaskPipeline
+
+[[autodoc]] FillMaskPipeline
+ - __call__
+ - all
+
+### QuestionAnsweringPipeline
+
+[[autodoc]] QuestionAnsweringPipeline
+ - __call__
+ - all
+
+### SummarizationPipeline
+
+[[autodoc]] SummarizationPipeline
+ - __call__
+ - all
+
+### TableQuestionAnsweringPipeline
+
+[[autodoc]] TableQuestionAnsweringPipeline
+ - __call__
+
+### TextClassificationPipeline
+
+[[autodoc]] TextClassificationPipeline
+ - __call__
+ - all
+
+### TextGenerationPipeline
+
+[[autodoc]] TextGenerationPipeline
+ - __call__
+ - all
+
+### Text2TextGenerationPipeline
+
+[[autodoc]] Text2TextGenerationPipeline
+ - __call__
+ - all
+
+### TokenClassificationPipeline
+
+[[autodoc]] TokenClassificationPipeline
+ - __call__
+ - all
+
+### TranslationPipeline
+
+[[autodoc]] TranslationPipeline
+ - __call__
+ - all
+
+### ZeroShotClassificationPipeline
+
+[[autodoc]] ZeroShotClassificationPipeline
+ - __call__
+ - all
+
+## Multimodal
+
+Pipelines available for multimodal tasks include the following.
+
+### DocumentQuestionAnsweringPipeline
+
+[[autodoc]] DocumentQuestionAnsweringPipeline
+ - __call__
+ - all
+
+### FeatureExtractionPipeline
+
+[[autodoc]] FeatureExtractionPipeline
+ - __call__
+ - all
+
+### ImageFeatureExtractionPipeline
+
+[[autodoc]] ImageFeatureExtractionPipeline
+ - __call__
+ - all
+
+### ImageToTextPipeline
+
+[[autodoc]] ImageToTextPipeline
+ - __call__
+ - all
+
+### ImageTextToTextPipeline
+
+[[autodoc]] ImageTextToTextPipeline
+ - __call__
+ - all
+
+### MaskGenerationPipeline
+
+[[autodoc]] MaskGenerationPipeline
+ - __call__
+ - all
+
+### VisualQuestionAnsweringPipeline
+
+[[autodoc]] VisualQuestionAnsweringPipeline
+ - __call__
+ - all
+
+## Parent class: `Pipeline`
+
+[[autodoc]] Pipeline
diff --git a/docs/source/main_classes/processors.md b/docs/source/main_classes/processors.md
new file mode 100644
index 0000000..5e943fc
--- /dev/null
+++ b/docs/source/main_classes/processors.md
@@ -0,0 +1,163 @@
+
+
+# Processors
+
+Processors can mean two different things in the Transformers library:
+- the objects that pre-process inputs for multi-modal models such as [Wav2Vec2](../model_doc/wav2vec2) (speech and text)
+ or [CLIP](../model_doc/clip) (text and vision)
+- deprecated objects that were used in older versions of the library to preprocess data for GLUE or SQUAD.
+
+## Multi-modal processors
+
+Any multi-modal model will require an object to encode or decode the data that groups several modalities (among text,
+vision and audio). This is handled by objects called processors, which group together two or more processing objects
+such as tokenizers (for the text modality), image processors (for vision) and feature extractors (for audio).
+
+Those processors inherit from the following base class that implements the saving and loading functionality:
+
+[[autodoc]] ProcessorMixin
+
+## Deprecated processors
+
+All processors follow the same architecture which is that of the
+[`~data.processors.utils.DataProcessor`]. The processor returns a list of
+[`~data.processors.utils.InputExample`]. These
+[`~data.processors.utils.InputExample`] can be converted to
+[`~data.processors.utils.InputFeatures`] in order to be fed to the model.
+
+[[autodoc]] data.processors.utils.DataProcessor
+
+[[autodoc]] data.processors.utils.InputExample
+
+[[autodoc]] data.processors.utils.InputFeatures
+
+## GLUE
+
+[General Language Understanding Evaluation (GLUE)](https://gluebenchmark.com/) is a benchmark that evaluates the
+performance of models across a diverse set of existing NLU tasks. It was released together with the paper [GLUE: A
+multi-task benchmark and analysis platform for natural language understanding](https://openreview.net/pdf?id=rJ4km2R5t7)
+
+This library hosts a total of 10 processors for the following tasks: MRPC, MNLI, MNLI (mismatched), CoLA, SST2, STSB,
+QQP, QNLI, RTE and WNLI.
+
+Those processors are:
+
+- [`~data.processors.utils.MrpcProcessor`]
+- [`~data.processors.utils.MnliProcessor`]
+- [`~data.processors.utils.MnliMismatchedProcessor`]
+- [`~data.processors.utils.Sst2Processor`]
+- [`~data.processors.utils.StsbProcessor`]
+- [`~data.processors.utils.QqpProcessor`]
+- [`~data.processors.utils.QnliProcessor`]
+- [`~data.processors.utils.RteProcessor`]
+- [`~data.processors.utils.WnliProcessor`]
+
+Additionally, the following method can be used to load values from a data file and convert them to a list of
+[`~data.processors.utils.InputExample`].
+
+[[autodoc]] data.processors.glue.glue_convert_examples_to_features
+
+
+## XNLI
+
+[The Cross-Lingual NLI Corpus (XNLI)](https://www.nyu.edu/projects/bowman/xnli/) is a benchmark that evaluates the
+quality of cross-lingual text representations. XNLI is crowd-sourced dataset based on [*MultiNLI*](http://www.nyu.edu/projects/bowman/multinli/): pairs of text are labeled with textual entailment annotations for 15
+different languages (including both high-resource language such as English and low-resource languages such as Swahili).
+
+It was released together with the paper [XNLI: Evaluating Cross-lingual Sentence Representations](https://arxiv.org/abs/1809.05053)
+
+This library hosts the processor to load the XNLI data:
+
+- [`~data.processors.utils.XnliProcessor`]
+
+Please note that since the gold labels are available on the test set, evaluation is performed on the test set.
+
+An example using these processors is given in the [run_xnli.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification/run_xnli.py) script.
+
+
+## SQuAD
+
+[The Stanford Question Answering Dataset (SQuAD)](https://rajpurkar.github.io/SQuAD-explorer//) is a benchmark that
+evaluates the performance of models on question answering. Two versions are available, v1.1 and v2.0. The first version
+(v1.1) was released together with the paper [SQuAD: 100,000+ Questions for Machine Comprehension of Text](https://arxiv.org/abs/1606.05250). The second version (v2.0) was released alongside the paper [Know What You Don't
+Know: Unanswerable Questions for SQuAD](https://arxiv.org/abs/1806.03822).
+
+This library hosts a processor for each of the two versions:
+
+### Processors
+
+Those processors are:
+
+- [`~data.processors.utils.SquadV1Processor`]
+- [`~data.processors.utils.SquadV2Processor`]
+
+They both inherit from the abstract class [`~data.processors.utils.SquadProcessor`]
+
+[[autodoc]] data.processors.squad.SquadProcessor
+ - all
+
+Additionally, the following method can be used to convert SQuAD examples into
+[`~data.processors.utils.SquadFeatures`] that can be used as model inputs.
+
+[[autodoc]] data.processors.squad.squad_convert_examples_to_features
+
+
+These processors as well as the aforementioned method can be used with files containing the data as well as with the
+*tensorflow_datasets* package. Examples are given below.
+
+
+### Example usage
+
+Here is an example using the processors as well as the conversion method using data files:
+
+```python
+# Loading a V2 processor
+processor = SquadV2Processor()
+examples = processor.get_dev_examples(squad_v2_data_dir)
+
+# Loading a V1 processor
+processor = SquadV1Processor()
+examples = processor.get_dev_examples(squad_v1_data_dir)
+
+features = squad_convert_examples_to_features(
+ examples=examples,
+ tokenizer=tokenizer,
+ max_seq_length=max_seq_length,
+ doc_stride=args.doc_stride,
+ max_query_length=max_query_length,
+ is_training=not evaluate,
+)
+```
+
+Using *tensorflow_datasets* is as easy as using a data file:
+
+```python
+# tensorflow_datasets only handle Squad V1.
+tfds_examples = tfds.load("squad")
+examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
+
+features = squad_convert_examples_to_features(
+ examples=examples,
+ tokenizer=tokenizer,
+ max_seq_length=max_seq_length,
+ doc_stride=args.doc_stride,
+ max_query_length=max_query_length,
+ is_training=not evaluate,
+)
+```
+
+Another example using these processors is given in the [run_squad.py](https://github.com/huggingface/transformers/tree/main/examples/legacy/question-answering/run_squad.py) script.
diff --git a/docs/source/main_classes/quantization.md b/docs/source/main_classes/quantization.md
new file mode 100755
index 0000000..3f44569
--- /dev/null
+++ b/docs/source/main_classes/quantization.md
@@ -0,0 +1,74 @@
+
+
+# Quantization
+
+Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Transformers supports the AWQ and GPTQ quantization algorithms and it supports 8-bit and 4-bit quantization with bitsandbytes.
+
+Quantization techniques that aren't supported in Transformers can be added with the [`HfQuantizer`] class.
+
+
+
+Learn how to quantize models in the [Quantization](../quantization) guide.
+
+
+
+## QuantoConfig
+
+[[autodoc]] QuantoConfig
+
+## AqlmConfig
+
+[[autodoc]] AqlmConfig
+
+## AwqConfig
+
+[[autodoc]] AwqConfig
+
+## EetqConfig
+[[autodoc]] EetqConfig
+
+## GPTQConfig
+
+[[autodoc]] GPTQConfig
+
+## BitsAndBytesConfig
+
+[[autodoc]] BitsAndBytesConfig
+
+## HfQuantizer
+
+[[autodoc]] quantizers.base.HfQuantizer
+
+## HqqConfig
+
+[[autodoc]] HqqConfig
+
+## FbgemmFp8Config
+
+[[autodoc]] FbgemmFp8Config
+
+## CompressedTensorsConfig
+
+[[autodoc]] CompressedTensorsConfig
+
+## TorchAoConfig
+
+[[autodoc]] TorchAoConfig
+
+## BitNetConfig
+
+[[autodoc]] BitNetConfig
diff --git a/docs/source/main_classes/text_generation.md b/docs/source/main_classes/text_generation.md
new file mode 100644
index 0000000..76a0f13
--- /dev/null
+++ b/docs/source/main_classes/text_generation.md
@@ -0,0 +1,59 @@
+
+
+# Generation
+
+Each framework has a generate method for text generation implemented in their respective `GenerationMixin` class:
+
+- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
+- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
+- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
+
+Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
+class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
+of the generation method.
+
+To learn how to inspect a model's generation configuration, what are the defaults, how to change the parameters ad hoc,
+and how to create and save a customized generation configuration, refer to the
+[text generation strategies guide](../generation_strategies). The guide also explains how to use related features,
+like token streaming.
+
+## GenerationConfig
+
+[[autodoc]] generation.GenerationConfig
+ - from_pretrained
+ - from_model_config
+ - save_pretrained
+ - update
+ - validate
+ - get_generation_mode
+
+## GenerationMixin
+
+[[autodoc]] GenerationMixin
+ - generate
+ - compute_transition_scores
+
+## TFGenerationMixin
+
+[[autodoc]] TFGenerationMixin
+ - generate
+ - compute_transition_scores
+
+## FlaxGenerationMixin
+
+[[autodoc]] FlaxGenerationMixin
+ - generate
diff --git a/docs/source/main_classes/tokenizer.md b/docs/source/main_classes/tokenizer.md
new file mode 100644
index 0000000..83d2ae5
--- /dev/null
+++ b/docs/source/main_classes/tokenizer.md
@@ -0,0 +1,104 @@
+
+
+# Tokenizer
+
+A tokenizer is in charge of preparing the inputs for a model. The library contains tokenizers for all the models. Most
+of the tokenizers are available in two flavors: a full python implementation and a "Fast" implementation based on the
+Rust library [๐ค Tokenizers](https://github.com/huggingface/tokenizers). The "Fast" implementations allows:
+
+1. a significant speed-up in particular when doing batched tokenization and
+2. additional methods to map between the original string (character and words) and the token space (e.g. getting the
+ index of the token comprising a given character or the span of characters corresponding to a given token).
+
+The base classes [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`]
+implement the common methods for encoding string inputs in model inputs (see below) and instantiating/saving python and
+"Fast" tokenizers either from a local file or directory or from a pretrained tokenizer provided by the library
+(downloaded from HuggingFace's AWS S3 repository). They both rely on
+[`~tokenization_utils_base.PreTrainedTokenizerBase`] that contains the common methods, and
+[`~tokenization_utils_base.SpecialTokensMixin`].
+
+[`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`] thus implement the main
+methods for using all the tokenizers:
+
+- Tokenizing (splitting strings in sub-word token strings), converting tokens strings to ids and back, and
+ encoding/decoding (i.e., tokenizing and converting to integers).
+- Adding new tokens to the vocabulary in a way that is independent of the underlying structure (BPE, SentencePiece...).
+- Managing special tokens (like mask, beginning-of-sentence, etc.): adding them, assigning them to attributes in the
+ tokenizer for easy access and making sure they are not split during tokenization.
+
+[`BatchEncoding`] holds the output of the
+[`~tokenization_utils_base.PreTrainedTokenizerBase`]'s encoding methods (`__call__`,
+`encode_plus` and `batch_encode_plus`) and is derived from a Python dictionary. When the tokenizer is a pure python
+tokenizer, this class behaves just like a standard python dictionary and holds the various model inputs computed by
+these methods (`input_ids`, `attention_mask`...). When the tokenizer is a "Fast" tokenizer (i.e., backed by
+HuggingFace [tokenizers library](https://github.com/huggingface/tokenizers)), this class provides in addition
+several advanced alignment methods which can be used to map between the original string (character and words) and the
+token space (e.g., getting the index of the token comprising a given character or the span of characters corresponding
+to a given token).
+
+
+# Multimodal Tokenizer
+
+Apart from that each tokenizer can be a "multimodal" tokenizer which means that the tokenizer will hold all relevant special tokens
+as part of tokenizer attributes for easier access. For example, if the tokenizer is loaded from a vision-language model like LLaVA, you will
+be able to access `tokenizer.image_token_id` to obtain the special image token used as a placeholder.
+
+To enable extra special tokens for any type of tokenizer, you have to add the following lines and save the tokenizer. Extra special tokens do not
+have to be modality related and can ne anything that the model often needs access to. In the below code, tokenizer at `output_dir` will have direct access
+to three more special tokens.
+
+```python
+vision_tokenizer = AutoTokenizer.from_pretrained(
+ "llava-hf/llava-1.5-7b-hf",
+ extra_special_tokens={"image_token": "", "boi_token": "", "eoi_token": ""}
+)
+print(vision_tokenizer.image_token, vision_tokenizer.image_token_id)
+("", 32000)
+```
+
+## PreTrainedTokenizer
+
+[[autodoc]] PreTrainedTokenizer
+ - __call__
+ - add_tokens
+ - add_special_tokens
+ - apply_chat_template
+ - batch_decode
+ - decode
+ - encode
+ - push_to_hub
+ - all
+
+## PreTrainedTokenizerFast
+
+The [`PreTrainedTokenizerFast`] depend on the [tokenizers](https://huggingface.co/docs/tokenizers) library. The tokenizers obtained from the ๐ค tokenizers library can be
+loaded very simply into ๐ค transformers. Take a look at the [Using tokenizers from ๐ค tokenizers](../fast_tokenizers) page to understand how this is done.
+
+[[autodoc]] PreTrainedTokenizerFast
+ - __call__
+ - add_tokens
+ - add_special_tokens
+ - apply_chat_template
+ - batch_decode
+ - decode
+ - encode
+ - push_to_hub
+ - all
+
+## BatchEncoding
+
+[[autodoc]] BatchEncoding
diff --git a/docs/source/main_classes/trainer.md b/docs/source/main_classes/trainer.md
new file mode 100644
index 0000000..21ba9ed
--- /dev/null
+++ b/docs/source/main_classes/trainer.md
@@ -0,0 +1,54 @@
+
+
+# Trainer
+
+The [`Trainer`] class provides an API for feature-complete training in PyTorch, and it supports distributed training on multiple GPUs/TPUs, mixed precision for [NVIDIA GPUs](https://nvidia.github.io/apex/), [AMD GPUs](https://rocm.docs.amd.com/en/latest/rocm.html), and [`torch.amp`](https://pytorch.org/docs/stable/amp.html) for PyTorch. [`Trainer`] goes hand-in-hand with the [`TrainingArguments`] class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.
+
+[`Seq2SeqTrainer`] and [`Seq2SeqTrainingArguments`] inherit from the [`Trainer`] and [`TrainingArguments`] classes and they're adapted for training models for sequence-to-sequence tasks such as summarization or translation.
+
+
+
+The [`Trainer`] class is optimized for ๐ค Transformers models and can have surprising behaviors
+when used with other models. When using it with your own model, make sure:
+
+- your model always return tuples or subclasses of [`~utils.ModelOutput`]
+- your model can compute the loss if a `labels` argument is provided and that loss is returned as the first
+ element of the tuple (if your model returns tuples)
+- your model can accept multiple label arguments (use `label_names` in [`TrainingArguments`] to indicate their name to the [`Trainer`]) but none of them should be named `"label"`
+
+
+
+## Trainer[[api-reference]]
+
+[[autodoc]] Trainer
+ - all
+
+## Seq2SeqTrainer
+
+[[autodoc]] Seq2SeqTrainer
+ - evaluate
+ - predict
+
+## TrainingArguments
+
+[[autodoc]] TrainingArguments
+ - all
+
+## Seq2SeqTrainingArguments
+
+[[autodoc]] Seq2SeqTrainingArguments
+ - all
diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md
new file mode 100644
index 0000000..b70cbb9
--- /dev/null
+++ b/docs/source/quicktour.md
@@ -0,0 +1,189 @@
+
+
+# Quicktour
+
+There are many ways to launch and run your code depending on your training environment ([torchrun](https://pytorch.org/docs/stable/elastic/run.html), [DeepSpeed](https://www.deepspeed.ai/), etc.) and available hardware. Accelerate offers a unified interface for launching and training on different distributed setups, allowing you to focus on your PyTorch training code instead of the intricacies of adapting your code to these different setups. This allows you to easily scale your PyTorch code for training and inference on distributed setups with hardware like GPUs and TPUs. Accelerate also provides Big Model Inference to make loading and running inference with really large models that usually don't fit in memory more accessible.
+
+This quicktour introduces the three main features of Accelerate:
+
+* a unified command line launching interface for distributed training scripts
+* a training library for adapting PyTorch training code to run on different distributed setups
+* Big Model Inference
+
+## Unified launch interface
+
+Accelerate automatically selects the appropriate configuration values for any given distributed training framework (DeepSpeed, FSDP, etc.) through a unified configuration file generated from the [`accelerate config`](package_reference/cli#accelerate-config) command. You could also pass the configuration values explicitly to the command line which is helpful in certain situations like if you're using SLURM.
+
+
+But in most cases, you should always run [`accelerate config`](package_reference/cli#accelerate-config) first to help Accelerate learn about your training setup.
+
+```bash
+accelerate config
+```
+
+The [`accelerate config`](package_reference/cli#accelerate-config) command creates and saves a default_config.yaml file in Accelerates cache folder. This file stores the configuration for your training environment, which helps Accelerate correctly launch your training script based on your machine.
+
+After you've configured your environment, you can test your setup with [`accelerate test`](package_reference/cli#accelerate-test), which launches a short script to test the distributed environment.
+
+```bash
+accelerate test
+```
+
+> [!TIP]
+> Add `--config_file` to the `accelerate test` or `accelerate launch` command to specify the location of the configuration file if it is saved in a non-default location like the cache.
+
+Once your environment is setup, launch your training script with [`accelerate launch`](package_reference/cli#accelerate-launch)!
+
+```bash
+accelerate launch path_to_script.py --args_for_the_script
+```
+
+To learn more, check out the [Launch distributed code](basic_tutorials/launch) tutorial for more information about launching your scripts.
+
+We also have a [configuration zoo](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates) which showcases a number of premade **minimal** example configurations for a variety of setups you can run.
+
+## Adapt training code
+
+The next main feature of Accelerate is the [`Accelerator`] class which adapts your PyTorch code to run on different distributed setups.
+
+You only need to add a few lines of code to your training script to enable it to run on multiple GPUs or TPUs.
+
+```diff
++ from accelerate import Accelerator
++ accelerator = Accelerator()
+
++ device = accelerator.device
++ model, optimizer, training_dataloader, scheduler = accelerator.prepare(
++ model, optimizer, training_dataloader, scheduler
++ )
+
+ for batch in training_dataloader:
+ optimizer.zero_grad()
+ inputs, targets = batch
+- inputs = inputs.to(device)
+- targets = targets.to(device)
+ outputs = model(inputs)
+ loss = loss_function(outputs, targets)
++ accelerator.backward(loss)
+ optimizer.step()
+ scheduler.step()
+```
+
+1. Import and instantiate the [`Accelerator`] class at the beginning of your training script. The [`Accelerator`] class initializes everything necessary for distributed training, and it automatically detects your training environment (a single machine with a GPU, a machine with several GPUs, several machines with multiple GPUs or a TPU, etc.) based on how the code was launched.
+
+```python
+from accelerate import Accelerator
+
+accelerator = Accelerator()
+```
+
+2. Remove calls like `.cuda()` on your model and input data. The [`Accelerator`] class automatically places these objects on the appropriate device for you.
+
+> [!WARNING]
+> This step is *optional* but it is considered best practice to allow Accelerate to handle device placement. You could also deactivate automatic device placement by passing `device_placement=False` when initializing the [`Accelerator`]. If you want to explicitly place objects on a device with `.to(device)`, make sure you use `accelerator.device` instead. For example, if you create an optimizer before placing a model on `accelerator.device`, training fails on a TPU.
+
+> [!WARNING]
+> Accelerate does not use non-blocking transfers by default for its automatic device placement, which can result in potentially unwanted CUDA synchronizations. You can enable non-blocking transfers by passing a [`~utils.dataclasses.DataLoaderConfiguration`] with `non_blocking=True` set as the `dataloader_config` when initializing the [`Accelerator`]. As usual, non-blocking transfers will only work if the dataloader also has `pin_memory=True` set. Be wary that using non-blocking transfers from GPU to CPU may cause incorrect results if it results in CPU operations being performed on non-ready tensors.
+
+```py
+device = accelerator.device
+```
+
+3. Pass all relevant PyTorch objects for training (optimizer, model, dataloader(s), learning rate scheduler) to the [`~Accelerator.prepare`] method as soon as they're created. This method wraps the model in a container optimized for your distributed setup, uses Accelerates version of the optimizer and scheduler, and creates a sharded version of your dataloader for distribution across GPUs or TPUs.
+
+```python
+model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, lr_scheduler
+)
+```
+
+4. Replace `loss.backward()` with [`~Accelerator.backward`] to use the correct `backward()` method for your training setup.
+
+```py
+accelerator.backward(loss)
+```
+
+Read [Accelerateโs internal mechanisms](concept_guides/internal_mechanism) guide to learn more details about how Accelerate adapts your code.
+
+### Distributed evaluation
+
+To perform distributed evaluation, pass your validation dataloader to the [`~Accelerator.prepare`] method:
+
+```python
+validation_dataloader = accelerator.prepare(validation_dataloader)
+```
+
+Each device in your distributed setup only receives a part of the evaluation data, which means you should group your predictions together with the [`~Accelerator.gather_for_metrics`] method. This method requires all tensors to be the same size on each process, so if your tensors have different sizes on each process (for instance when dynamically padding to the maximum length in a batch), you should use the [`~Accelerator.pad_across_processes`] method to pad you tensor to the largest size across processes. Note that the tensors needs to be 1D and that we concatenate the tensors along the first dimension.
+
+```python
+for inputs, targets in validation_dataloader:
+ predictions = model(inputs)
+ # Gather all predictions and targets
+ all_predictions, all_targets = accelerator.gather_for_metrics((predictions, targets))
+ # Example of use with a *Datasets.Metric*
+ metric.add_batch(all_predictions, all_targets)
+```
+
+For more complex cases (e.g. 2D tensors, don't want to concatenate tensors, dict of 3D tensors), you can pass `use_gather_object=True` in `gather_for_metrics`. This will return the list of objects after gathering. Note that using it with GPU tensors is not well supported and inefficient.
+
+> [!TIP]
+> Data at the end of a dataset may be duplicated so the batch can be equally divided among all workers. The [`~Accelerator.gather_for_metrics`] method automatically removes the duplicated data to calculate a more accurate metric.
+
+## Big Model Inference
+
+Accelerate's Big Model Inference has two main features, [`~accelerate.init_empty_weights`] and [`~accelerate.load_checkpoint_and_dispatch`], to load large models for inference that typically don't fit into memory.
+
+> [!TIP]
+> Take a look at the [Handling big models for inference](concept_guides/big_model_inference) guide for a better understanding of how Big Model Inference works under the hood.
+
+### Empty weights initialization
+
+The [`~accelerate.init_empty_weights`] context manager initializes models of any size by creating a *model skeleton* and moving and placing parameters each time they're created to PyTorch's [**meta**](https://pytorch.org/docs/main/meta.html) device. This way, not all weights are immediately loaded and only a small part of the model is loaded into memory at a time.
+
+For example, loading an empty [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) model takes significantly less memory than fully loading the models and weights on the CPU.
+
+```py
+from accelerate import init_empty_weights
+from transformers import AutoConfig, AutoModelForCausalLM
+
+config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+with init_empty_weights():
+ model = AutoModelForCausalLM.from_config(config)
+```
+
+### Load and dispatch weights
+
+The [`~accelerate.load_checkpoint_and_dispatch`] function loads full or sharded checkpoints into the empty model, and automatically distribute weights across all available devices.
+
+The `device_map` parameter determines where to place each model layer, and specifiying `"auto"` places them on the GPU first, then the CPU, and finally the hard drive as memory-mapped tensors if there's still not enough memory. Use the `no_split_module_classes` parameter to indicate which modules shouldn't be split across devices (typically those with a residual connection).
+
+```py
+from accelerate import load_checkpoint_and_dispatch
+
+model_checkpoint = "your-local-model-folder"
+model = load_checkpoint_and_dispatch(
+ model, checkpoint=model_checkpoint, device_map="auto", no_split_module_classes=['Block']
+)
+```
+
+## Next steps
+
+Now that you've been introduced to the main Accelerate features, your next steps could include:
+
+* Check out the [tutorials](basic_tutorials/overview) for a gentle walkthrough of Accelerate. This is especially useful if you're new to distributed training and the library.
+* Dive into the [guides](usage_guides/explore) to see how to use Accelerate for specific use-cases.
+* Deepen your conceptual understanding of how Accelerate works internally by reading the [concept guides](concept_guides/internal_mechanism).
+* Look up classes and commands in the [API reference](package_reference/accelerator) to see what parameters and options are available.
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000..27938cd
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,276 @@
+
+
+# In this folder we showcase various full examples using ๐ค Accelerate
+
+## Simple NLP example
+
+The [nlp_example.py](./nlp_example.py) script is a simple example to train a Bert model on a classification task ([GLUE's MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398)).
+
+Prior to running it you should install ๐ค Dataset and ๐ค Transformers:
+
+```bash
+pip install datasets evaluate transformers
+```
+
+The same script can be run in any of the following configurations:
+- single CPU or single GPU
+- multi CPUs
+- multi GPUs (using PyTorch distributed mode)
+- (multi) TPUs
+- fp16 (mixed-precision) or fp32 (normal precision)
+
+To run it in each of these various modes, use the following commands:
+- single CPU:
+ * from a server without GPU
+ ```bash
+ python ./nlp_example.py
+ ```
+ * from any server by passing `cpu=True` to the `Accelerator`.
+ ```bash
+ python ./nlp_example.py --cpu
+ ```
+ * from any server with Accelerate launcher
+ ```bash
+ accelerate launch --cpu ./nlp_example.py
+ ```
+- single GPU:
+ ```bash
+ python ./nlp_example.py # from a server with a GPU
+ ```
+- with fp16 (mixed-precision)
+ * from any server by passing `mixed_precison=fp16` to the `Accelerator`.
+ ```bash
+ python ./nlp_example.py --mixed_precision fp16
+ ```
+ * from any server with Accelerate launcher
+ ```bash
+ accelerate launch --mixed_precision fp16 ./nlp_example.py
+- multi CPUs (requires Open MPI, Intel MPI, or MVAPICH)
+ * With Accelerate config and launcher, execute the following from node 0:
+ ```bash
+ accelerate config # Select to have accelerate launch mpirun
+ accelerate launch ./nlp_example.py # This will run the script on each server
+ ```
+ * With Intel MPI:
+ ```bash
+ export CCL_WORKER_COUNT=1
+ export MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip
+ mpirun -f hostfile -n 16 -ppn 4 python ./nlp_example.py
+ ```
+- multi GPUs (using PyTorch distributed mode)
+ * With Accelerate config and launcher
+ ```bash
+ accelerate config # This will create a config file on your server
+ accelerate launch ./nlp_example.py # This will run the script on your server
+ ```
+ * With traditional PyTorch launcher (`python -m torch.distributed.run` can be used instead of `torchrun`)
+ ```bash
+ torchrun --nproc_per_node 2 ./nlp_example.py
+ ```
+- multi GPUs, multi node (several machines, using PyTorch distributed mode)
+ * With Accelerate config and launcher, on each machine:
+ ```bash
+ accelerate config # This will create a config file on each server
+ accelerate launch ./nlp_example.py # This will run the script on each server
+ ```
+ * With PyTorch launcher only (`python -m torch.distributed.run` can be used instead of `torchrun`). Run this command on each node:
+ ```bash
+ torchrun \ # python -m torch.distributed.run
+ --nproc_per_node 2 \
+ --nnodes 2 \
+ --rdzv_id 2299 \ # A unique job id
+ --rdzv_backend c10d \
+ --rdzv_endpoint master_node_ip_address:29500 \
+ ./nlp_example.py
+ ```
+- (multi) TPUs
+ * With Accelerate config and launcher
+ ```bash
+ accelerate config # This will create a config file on your TPU server
+ accelerate launch ./nlp_example.py # This will run the script on each server
+ ```
+ * In PyTorch:
+ Add an `xmp.spawn` line in your script as you usually do.
+
+
+## Simple vision example
+
+The [cv_example.py](./cv_example.py) script is a simple example to fine-tune a ResNet-50 on a classification task ([Ofxord-IIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/)).
+
+The same script can be run in any of the following configurations:
+- single CPU or single GPU
+- multi CPUs
+- multi GPUs (using PyTorch distributed mode)
+- (multi) TPUs
+- fp16 (mixed-precision) or fp32 (normal precision)
+
+Prior to running it you should install timm and torchvision:
+
+```bash
+pip install timm torchvision
+```
+
+and you should download the data with the following commands:
+
+```bash
+wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
+tar -xzf images.tar.gz
+```
+
+To run it in each of these various modes, use the following commands:
+- single CPU:
+ * from a server without GPU
+ ```bash
+ python ./cv_example.py --data_dir path_to_data
+ ```
+ * from any server by passing `cpu=True` to the `Accelerator`.
+ ```bash
+ python ./cv_example.py --data_dir path_to_data --cpu
+ ```
+ * from any server with Accelerate launcher
+ ```bash
+ accelerate launch --cpu ./cv_example.py --data_dir path_to_data
+ ```
+- single GPU:
+ ```bash
+ python ./cv_example.py # from a server with a GPU
+ ```
+- with fp16 (mixed-precision)
+ * from any server by passing `mixed_precison=fp16` to the `Accelerator`.
+ ```bash
+ python ./cv_example.py --data_dir path_to_data --mixed_precison fp16
+ ```
+ * from any server with Accelerate launcher
+ ```bash
+ accelerate launch --mixed_precison fp16 ./cv_example.py --data_dir path_to_data
+- multi CPUs (requires Open MPI, Intel MPI, or MVAPICH)
+ * With Accelerate config and launcher, run the following from node 0:
+ ```bash
+ accelerate config --config_file config.yaml # Select to have accelerate launch mpirun
+ accelerate launch ./cv_example.py --data_dir path_to_data # This will run the script on each server
+ ```
+ * With Intel MPI, execute mpirun from node 0:
+ ```bash
+ export CCL_WORKER_COUNT=1
+ export MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip
+ mpirun -f hostfile -n 16 -ppn 4 python ./cv_example.py --data_dir path_to_data
+ ```
+- multi GPUs (using PyTorch distributed mode)
+ * With Accelerate config and launcher
+ ```bash
+ accelerate config --config_file config.yaml # This will create a config file on your server to `config.yaml`
+ accelerate launch --config_file config.yaml ./cv_example.py --data_dir path_to_data # This will run the script on your server
+ ```
+ * With traditional PyTorch launcher (`python -m torch.distributed.run` can be used instead of `torchrun`)
+ ```bash
+ torchrun --nproc_per_node 2 ./cv_example.py --data_dir path_to_data
+ ```
+- multi GPUs, multi node (several machines, using PyTorch distributed mode)
+ * With Accelerate config and launcher, on each machine:
+ ```bash
+ accelerate config --config_file config.yaml # This will create a config file on your server to `config.yaml`
+ accelerate launch --config_file config.yaml ./cv_example.py --data_dir path_to_data # This will run the script on each server
+ ```
+ * With PyTorch launcher only (`python -m torch.distributed.run` can be used instead of `torchrun`). Run this command on each node:
+ ```bash
+ torchrun \ # python -m torch.distributed.run
+ --nproc_per_node 2 \
+ --nnodes 2 \
+ --rdzv_id 2299 \ # A unique job id
+ --rdzv_backend c10d \
+ --rdzv_endpoint master_node_ip_address:29500 \
+ ./cv_example.py --data_dir path_to_data
+ ```
+- (multi) TPUs
+ * With Accelerate config and launcher
+ ```bash
+ accelerate config --config_file config.yaml # This will create a config file on your server to `config.yaml`
+ accelerate launch --config_file config.yaml ./cv_example.py --data_dir path_to_data # This will run the script on each server
+ ```
+ * In PyTorch:
+ Add an `xmp.spawn` line in your script as you usually do.
+
+### Simple vision example (GANs)
+
+- [huggan project](https://github.com/huggingface/community-events/tree/main/huggan)
+
+
+### Using AWS SageMaker integration
+- [Examples showcasing AWS SageMaker integration of ๐ค Accelerate.](https://github.com/pacman100/accelerate-aws-sagemaker)
+
+## Configuration zoo
+In [/config_yaml_templates](./config_yaml_templates/) we have a variety of *minimal* `config.yaml` templates and examples to help you learn
+how to create your own configuration files depending on the scenario.
+
+## SLURM Scripts
+In [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) and [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we present two scripts for running the examples on a machine with [SLURM](https://slurm.schedmd.com/documentation.html) workload manager.
+
+In [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) the only parameter in the launcher that needs to be modified is `--num_processes`, which determines the number of GPUs we will use. In this case, using the environment variable `$SLURM_GPUS`, we indicate that we want to utilize all the GPUs available on the node we have requested.
+
+In [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we must specify the number of nodes that will be part of the training (`--num_machines`), how many GPUs we will use in total (`--num_processes`), the [`backend`](https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend), `--main_process_ip` which will be the address the master node and the `--main_process_port`.
+
+In [/slurm/submit_multicpu.sh](./slurm/submit_multicpu.sh) we must specify the number of nodes that will be part of the training (`--num_machines`), how many CPU processes we will use in total (`--num_processes`), the [`backend`](https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend), `--main_process_ip` which will be the address the master node and the `--main_process_port`. `mpirun_hostfile` specifies to run the job using MPIRun.
+
+In both scripts, we run `activateEnviroment.sh` at the beginning. This script should contain the necessary instructions to initialize the environment for execution. Below, we show an example that loads the necessary libraries ([Environment modules](https://github.com/cea-hpc/modules)), activates the Python environment, and sets up various environment variables, most of them to run the scripts in offline mode in case we don't have internet connection from the cluster.
+
+```bash
+# activateEnvironment.sh
+module purge
+module load anaconda3/2020.02 cuda/10.2 cudnn/8.0.5 nccl/2.9.9 arrow/7.0.0 openmpi
+source activate /home/nct01/nct01328/pytorch_antoni_local
+
+export HF_HOME=/gpfs/projects/nct01/nct01328/
+export HF_LOCAL_HOME=/gpfs/projects/nct01/nct01328/HF_LOCAL
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+export PYTHONPATH=/home/nct01/nct01328/transformers-in-supercomputers:$PYTHONPATH
+export GPUS_PER_NODE=4
+```
+
+## Simple Multi-GPU Hardware Launcher (using an external platform)
+
+[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate
+on multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can
+easily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then
+run the script to automatically launch multi GPU training on remote hardware.
+
+This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
+cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
+with `pip install runhouse`, and you can refer to
+[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
+for hardware setup instructions, or this
+[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.
+
+## Finer Examples
+
+While the first two scripts are extremely barebones when it comes to what you can do with accelerate, more advanced features are documented in two other locations.
+
+### `by_feature` examples
+
+These scripts are *individual* examples highlighting one particular feature or use-case within Accelerate. They all stem from the [nlp_example.py](./nlp_example.py) script, and any changes or modifications is denoted with a `# New Code #` comment.
+
+Read the README.md file located in the `by_feature` folder for more information.
+
+### `complete_*` examples
+
+These two scripts contain *every* single feature currently available in Accelerate in one place, as one giant script.
+
+New arguments that can be passed include:
+
+- `checkpointing_steps`, whether the various states should be saved at the end of every `n` steps, or `"epoch"` for each epoch. States are then saved to folders named `step_{n}` or `epoch_{n}`
+- `resume_from_checkpoint`, should be used if you want to resume training off of a previous call to the script and passed a `checkpointing_steps` to it.
+- `with_tracking`, should be used if you want to log the training run using all available experiment trackers in your environment. Currently supported trackers include TensorBoard, Weights and Biases, and CometML.
diff --git a/examples/requirements.txt b/examples/requirements.txt
new file mode 100644
index 0000000..fd571f2
--- /dev/null
+++ b/examples/requirements.txt
@@ -0,0 +1,5 @@
+accelerate # used to be installed in Amazon SageMaker environment
+evaluate
+datasets==2.3.2
+schedulefree
+huggingface_hub>=0.20.0
diff --git a/poetry.lock b/poetry.lock
new file mode 100644
index 0000000..facbf6a
--- /dev/null
+++ b/poetry.lock
@@ -0,0 +1,891 @@
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+
+[[package]]
+name = "certifi"
+version = "2024.8.30"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"},
+ {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"},
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.4.0"
+description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4f9fc98dad6c2eaa32fc3af1417d95b5e3d08aff968df0cd320066def971f9a6"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0de7b687289d3c1b3e8660d0741874abe7888100efe14bd0f9fd7141bcbda92b"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5ed2e36c3e9b4f21dd9422f6893dec0abf2cca553af509b10cd630f878d3eb99"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d3ff7fc90b98c637bda91c89d51264a3dcf210cade3a2c6f838c7268d7a4ca"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1110e22af8ca26b90bd6364fe4c763329b0ebf1ee213ba32b68c73de5752323d"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86f4e8cca779080f66ff4f191a685ced73d2f72d50216f7112185dc02b90b9b7"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f683ddc7eedd742e2889d2bfb96d69573fde1d92fcb811979cdb7165bb9c7d3"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27623ba66c183eca01bf9ff833875b459cad267aeeb044477fedac35e19ba907"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f606a1881d2663630ea5b8ce2efe2111740df4b687bd78b34a8131baa007f79b"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0b309d1747110feb25d7ed6b01afdec269c647d382c857ef4663bbe6ad95a912"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:136815f06a3ae311fae551c3df1f998a1ebd01ddd424aa5603a4336997629e95"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:14215b71a762336254351b00ec720a8e85cada43b987da5a042e4ce3e82bd68e"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:79983512b108e4a164b9c8d34de3992f76d48cadc9554c9e60b43f308988aabe"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-win32.whl", hash = "sha256:c94057af19bc953643a33581844649a7fdab902624d2eb739738a30e2b3e60fc"},
+ {file = "charset_normalizer-3.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55f56e2ebd4e3bc50442fbc0888c9d8c94e4e06a933804e2af3e89e2f9c1c749"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0d99dd8ff461990f12d6e42c7347fd9ab2532fb70e9621ba520f9e8637161d7c"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c57516e58fd17d03ebe67e181a4e4e2ccab1168f8c2976c6a334d4f819fe5944"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dba5d19c4dfab08e58d5b36304b3f92f3bd5d42c1a3fa37b5ba5cdf6dfcbcee"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf4475b82be41b07cc5e5ff94810e6a01f276e37c2d55571e3fe175e467a1a1c"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce031db0408e487fd2775d745ce30a7cd2923667cf3b69d48d219f1d8f5ddeb6"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ff4e7cdfdb1ab5698e675ca622e72d58a6fa2a8aa58195de0c0061288e6e3ea"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3710a9751938947e6327ea9f3ea6332a09bf0ba0c09cae9cb1f250bd1f1549bc"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82357d85de703176b5587dbe6ade8ff67f9f69a41c0733cf2425378b49954de5"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47334db71978b23ebcf3c0f9f5ee98b8d65992b65c9c4f2d34c2eaf5bcaf0594"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8ce7fd6767a1cc5a92a639b391891bf1c268b03ec7e021c7d6d902285259685c"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f1a2f519ae173b5b6a2c9d5fa3116ce16e48b3462c8b96dfdded11055e3d6365"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:63bc5c4ae26e4bc6be6469943b8253c0fd4e4186c43ad46e713ea61a0ba49129"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bcb4f8ea87d03bc51ad04add8ceaf9b0f085ac045ab4d74e73bbc2dc033f0236"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-win32.whl", hash = "sha256:9ae4ef0b3f6b41bad6366fb0ea4fc1d7ed051528e113a60fa2a65a9abb5b1d99"},
+ {file = "charset_normalizer-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cee4373f4d3ad28f1ab6290684d8e2ebdb9e7a1b74fdc39e4c211995f77bec27"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0713f3adb9d03d49d365b70b84775d0a0d18e4ab08d12bc46baa6132ba78aaf6"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:de7376c29d95d6719048c194a9cf1a1b0393fbe8488a22008610b0361d834ecf"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a51b48f42d9358460b78725283f04bddaf44a9358197b889657deba38f329db"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b295729485b06c1a0683af02a9e42d2caa9db04a373dc38a6a58cdd1e8abddf1"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee803480535c44e7f5ad00788526da7d85525cfefaf8acf8ab9a310000be4b03"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d59d125ffbd6d552765510e3f31ed75ebac2c7470c7274195b9161a32350284"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cda06946eac330cbe6598f77bb54e690b4ca93f593dee1568ad22b04f347c15"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07afec21bbbbf8a5cc3651aa96b980afe2526e7f048fdfb7f1014d84acc8b6d8"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6b40e8d38afe634559e398cc32b1472f376a4099c75fe6299ae607e404c033b2"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b8dcd239c743aa2f9c22ce674a145e0a25cb1566c495928440a181ca1ccf6719"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:84450ba661fb96e9fd67629b93d2941c871ca86fc38d835d19d4225ff946a631"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44aeb140295a2f0659e113b31cfe92c9061622cadbc9e2a2f7b8ef6b1e29ef4b"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1db4e7fefefd0f548d73e2e2e041f9df5c59e178b4c72fbac4cc6f535cfb1565"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-win32.whl", hash = "sha256:5726cf76c982532c1863fb64d8c6dd0e4c90b6ece9feb06c9f202417a31f7dd7"},
+ {file = "charset_normalizer-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b197e7094f232959f8f20541ead1d9862ac5ebea1d58e9849c1bf979255dfac9"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67"},
+ {file = "charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dbe03226baf438ac4fda9e2d0715022fd579cb641c4cf639fa40d53b2fe6f3e2"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd9a8bd8900e65504a305bf8ae6fa9fbc66de94178c420791d0293702fce2df7"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8831399554b92b72af5932cdbbd4ddc55c55f631bb13ff8fe4e6536a06c5c51"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a14969b8691f7998e74663b77b4c36c0337cb1df552da83d5c9004a93afdb574"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaf7c1524c0542ee2fc82cc8ec337f7a9f7edee2532421ab200d2b920fc97cf"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425c5f215d0eecee9a56cdb703203dda90423247421bf0d67125add85d0c4455"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:d5b054862739d276e09928de37c79ddeec42a6e1bfc55863be96a36ba22926f6"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:f3e73a4255342d4eb26ef6df01e3962e73aa29baa3124a8e824c5d3364a65748"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:2f6c34da58ea9c1a9515621f4d9ac379871a8f21168ba1b5e09d74250de5ad62"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f09cb5a7bbe1ecae6e87901a2eb23e0256bb524a79ccc53eb0b7629fbe7677c4"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:0099d79bdfcf5c1f0c2c72f91516702ebf8b0b8ddd8905f97a8aecf49712c621"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-win32.whl", hash = "sha256:9c98230f5042f4945f957d006edccc2af1e03ed5e37ce7c373f00a5a4daa6149"},
+ {file = "charset_normalizer-3.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62f60aebecfc7f4b82e3f639a7d1433a20ec32824db2199a11ad4f5e146ef5ee"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:af73657b7a68211996527dbfeffbb0864e043d270580c5aef06dc4b659a4b578"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cab5d0b79d987c67f3b9e9c53f54a61360422a5a0bc075f43cab5621d530c3b6"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9289fd5dddcf57bab41d044f1756550f9e7cf0c8e373b8cdf0ce8773dc4bd417"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b493a043635eb376e50eedf7818f2f322eabbaa974e948bd8bdd29eb7ef2a51"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fa2566ca27d67c86569e8c85297aaf413ffab85a8960500f12ea34ff98e4c41"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8e538f46104c815be19c975572d74afb53f29650ea2025bbfaef359d2de2f7f"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fd30dc99682dc2c603c2b315bded2799019cea829f8bf57dc6b61efde6611c8"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2006769bd1640bdf4d5641c69a3d63b71b81445473cac5ded39740a226fa88ab"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc15e99b2d8a656f8e666854404f1ba54765871104e50c8e9813af8a7db07f12"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ab2e5bef076f5a235c3774b4f4028a680432cded7cad37bba0fd90d64b187d19"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:4ec9dd88a5b71abfc74e9df5ebe7921c35cbb3b641181a531ca65cdb5e8e4dea"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:43193c5cda5d612f247172016c4bb71251c784d7a4d9314677186a838ad34858"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:aa693779a8b50cd97570e5a0f343538a8dbd3e496fa5dcb87e29406ad0299654"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-win32.whl", hash = "sha256:7706f5850360ac01d80c89bcef1640683cc12ed87f42579dab6c5d3ed6888613"},
+ {file = "charset_normalizer-3.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:c3e446d253bd88f6377260d07c895816ebf33ffffd56c1c792b13bff9c3e1ade"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:980b4f289d1d90ca5efcf07958d3eb38ed9c0b7676bf2831a54d4f66f9c27dfa"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f28f891ccd15c514a0981f3b9db9aa23d62fe1a99997512b0491d2ed323d229a"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8aacce6e2e1edcb6ac625fb0f8c3a9570ccc7bfba1f63419b3769ccf6a00ed0"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7af3717683bea4c87acd8c0d3d5b44d56120b26fd3f8a692bdd2d5260c620a"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff2ed8194587faf56555927b3aa10e6fb69d931e33953943bc4f837dfee2242"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e91f541a85298cf35433bf66f3fab2a4a2cff05c127eeca4af174f6d497f0d4b"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309a7de0a0ff3040acaebb35ec45d18db4b28232f21998851cfa709eeff49d62"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:285e96d9d53422efc0d7a17c60e59f37fbf3dfa942073f666db4ac71e8d726d0"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5d447056e2ca60382d460a604b6302d8db69476fd2015c81e7c35417cfabe4cd"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:20587d20f557fe189b7947d8e7ec5afa110ccf72a3128d61a2a387c3313f46be"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:130272c698667a982a5d0e626851ceff662565379baf0ff2cc58067b81d4f11d"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ab22fbd9765e6954bc0bcff24c25ff71dcbfdb185fcdaca49e81bac68fe724d3"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7782afc9b6b42200f7362858f9e73b1f8316afb276d316336c0ec3bd73312742"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-win32.whl", hash = "sha256:2de62e8801ddfff069cd5c504ce3bc9672b23266597d4e4f50eda28846c322f2"},
+ {file = "charset_normalizer-3.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:95c3c157765b031331dd4db3c775e58deaee050a3042fcad72cbc4189d7c8dca"},
+ {file = "charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079"},
+ {file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"},
+]
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.2.2"
+description = "Backport of PEP 654 (exception groups)"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
+ {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "filelock"
+version = "3.16.1"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"},
+ {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"},
+]
+
+[package.extras]
+docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"]
+typing = ["typing-extensions (>=4.12.2)"]
+
+[[package]]
+name = "fsspec"
+version = "2024.10.0"
+description = "File-system specification"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "fsspec-2024.10.0-py3-none-any.whl", hash = "sha256:03b9a6785766a4de40368b88906366755e2819e758b83705c88cd7cb5fe81871"},
+ {file = "fsspec-2024.10.0.tar.gz", hash = "sha256:eda2d8a4116d4f2429db8550f2457da57279247dd930bb12f821b58391359493"},
+]
+
+[package.extras]
+abfs = ["adlfs"]
+adl = ["adlfs"]
+arrow = ["pyarrow (>=1)"]
+dask = ["dask", "distributed"]
+dev = ["pre-commit", "ruff"]
+doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"]
+dropbox = ["dropbox", "dropboxdrivefs", "requests"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
+fuse = ["fusepy"]
+gcs = ["gcsfs"]
+git = ["pygit2"]
+github = ["requests"]
+gs = ["gcsfs"]
+gui = ["panel"]
+hdfs = ["pyarrow (>=1)"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
+libarchive = ["libarchive-c"]
+oci = ["ocifs"]
+s3 = ["s3fs"]
+sftp = ["paramiko"]
+smb = ["smbprotocol"]
+ssh = ["paramiko"]
+test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
+test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
+test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
+tqdm = ["tqdm"]
+
+[[package]]
+name = "huggingface-hub"
+version = "0.26.3"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+optional = false
+python-versions = ">=3.8.0"
+files = [
+ {file = "huggingface_hub-0.26.3-py3-none-any.whl", hash = "sha256:e66aa99e569c2d5419240a9e553ad07245a5b1300350bfbc5a4945cf7432991b"},
+ {file = "huggingface_hub-0.26.3.tar.gz", hash = "sha256:90e1fe62ffc26757a073aaad618422b899ccf9447c2bba8c902a90bef5b42e1d"},
+]
+
+[package.dependencies]
+filelock = "*"
+fsspec = ">=2023.5.0"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+requests = "*"
+tqdm = ">=4.42.1"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+cli = ["InquirerPy (==0.3.4)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+hf-transfer = ["hf-transfer (>=0.1.4)"]
+inference = ["aiohttp"]
+quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.5.0)"]
+tensorflow = ["graphviz", "pydot", "tensorflow"]
+tensorflow-testing = ["keras (<3.0)", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["safetensors[torch]", "torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
+
+[[package]]
+name = "idna"
+version = "3.10"
+description = "Internationalized Domain Names in Applications (IDNA)"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
+ {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
+]
+
+[package.extras]
+all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
+
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+description = "brain-dead simple config-ini parsing"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
+ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+]
+
+[[package]]
+name = "numpy"
+version = "2.1.3"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.10"
+files = [
+ {file = "numpy-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff"},
+ {file = "numpy-2.1.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5"},
+ {file = "numpy-2.1.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1"},
+ {file = "numpy-2.1.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd"},
+ {file = "numpy-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3"},
+ {file = "numpy-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098"},
+ {file = "numpy-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c"},
+ {file = "numpy-2.1.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4"},
+ {file = "numpy-2.1.3-cp310-cp310-win32.whl", hash = "sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23"},
+ {file = "numpy-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0"},
+ {file = "numpy-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d"},
+ {file = "numpy-2.1.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41"},
+ {file = "numpy-2.1.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9"},
+ {file = "numpy-2.1.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09"},
+ {file = "numpy-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a"},
+ {file = "numpy-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b"},
+ {file = "numpy-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee"},
+ {file = "numpy-2.1.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0"},
+ {file = "numpy-2.1.3-cp311-cp311-win32.whl", hash = "sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9"},
+ {file = "numpy-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2"},
+ {file = "numpy-2.1.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e"},
+ {file = "numpy-2.1.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958"},
+ {file = "numpy-2.1.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8"},
+ {file = "numpy-2.1.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564"},
+ {file = "numpy-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512"},
+ {file = "numpy-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b"},
+ {file = "numpy-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc"},
+ {file = "numpy-2.1.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0"},
+ {file = "numpy-2.1.3-cp312-cp312-win32.whl", hash = "sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9"},
+ {file = "numpy-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a"},
+ {file = "numpy-2.1.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f"},
+ {file = "numpy-2.1.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598"},
+ {file = "numpy-2.1.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57"},
+ {file = "numpy-2.1.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe"},
+ {file = "numpy-2.1.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43"},
+ {file = "numpy-2.1.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56"},
+ {file = "numpy-2.1.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a"},
+ {file = "numpy-2.1.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef"},
+ {file = "numpy-2.1.3-cp313-cp313-win32.whl", hash = "sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f"},
+ {file = "numpy-2.1.3-cp313-cp313-win_amd64.whl", hash = "sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed"},
+ {file = "numpy-2.1.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f"},
+ {file = "numpy-2.1.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4"},
+ {file = "numpy-2.1.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e"},
+ {file = "numpy-2.1.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0"},
+ {file = "numpy-2.1.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408"},
+ {file = "numpy-2.1.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6"},
+ {file = "numpy-2.1.3-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f"},
+ {file = "numpy-2.1.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17"},
+ {file = "numpy-2.1.3-cp313-cp313t-win32.whl", hash = "sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48"},
+ {file = "numpy-2.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4"},
+ {file = "numpy-2.1.3-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f"},
+ {file = "numpy-2.1.3-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4"},
+ {file = "numpy-2.1.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d"},
+ {file = "numpy-2.1.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb"},
+ {file = "numpy-2.1.3.tar.gz", hash = "sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761"},
+]
+
+[[package]]
+name = "packaging"
+version = "24.2"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
+ {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
+]
+
+[[package]]
+name = "pluggy"
+version = "1.5.0"
+description = "plugin and hook calling mechanisms for python"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
+name = "pytest"
+version = "8.3.4"
+description = "pytest: simple powerful testing with Python"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
+ {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=1.5,<2"
+tomli = {version = ">=1", markers = "python_version < \"3.11\""}
+
+[package.extras]
+dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pyyaml"
+version = "6.0.2"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"},
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"},
+ {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"},
+ {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"},
+ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
+]
+
+[[package]]
+name = "regex"
+version = "2024.11.6"
+description = "Alternative regular expression module, to replace re."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91"},
+ {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0"},
+ {file = "regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62"},
+ {file = "regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e"},
+ {file = "regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45"},
+ {file = "regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9"},
+ {file = "regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad"},
+ {file = "regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54"},
+ {file = "regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d"},
+ {file = "regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff"},
+ {file = "regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3a51ccc315653ba012774efca4f23d1d2a8a8f278a6072e29c7147eee7da446b"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad182d02e40de7459b73155deb8996bbd8e96852267879396fb274e8700190e3"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba9b72e5643641b7d41fa1f6d5abda2c9a263ae835b917348fc3c928182ad467"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40291b1b89ca6ad8d3f2b82782cc33807f1406cf68c8d440861da6304d8ffbbd"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf58d0e516ee426a48f7b2c03a332a4114420716d55769ff7108c37a09951bf"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a36fdf2af13c2b14738f6e973aba563623cb77d753bbbd8d414d18bfaa3105dd"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cee317bfc014c2419a76bcc87f071405e3966da434e03e13beb45f8aced1a6"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50153825ee016b91549962f970d6a4442fa106832e14c918acd1c8e479916c4f"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea1bfda2f7162605f6e8178223576856b3d791109f15ea99a9f95c16a7636fb5"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:df951c5f4a1b1910f1a99ff42c473ff60f8225baa1cdd3539fe2819d9543e9df"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:072623554418a9911446278f16ecb398fb3b540147a7828c06e2011fa531e773"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f654882311409afb1d780b940234208a252322c24a93b442ca714d119e68086c"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:89d75e7293d2b3e674db7d4d9b1bee7f8f3d1609428e293771d1a962617150cc"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f65557897fc977a44ab205ea871b690adaef6b9da6afda4790a2484b04293a5f"},
+ {file = "regex-2024.11.6-cp38-cp38-win32.whl", hash = "sha256:6f44ec28b1f858c98d3036ad5d7d0bfc568bdd7a74f9c24e25f41ef1ebfd81a4"},
+ {file = "regex-2024.11.6-cp38-cp38-win_amd64.whl", hash = "sha256:bb8f74f2f10dbf13a0be8de623ba4f9491faf58c24064f32b65679b021ed0001"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5704e174f8ccab2026bd2f1ab6c510345ae8eac818b613d7d73e785f1310f839"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:220902c3c5cc6af55d4fe19ead504de80eb91f786dc102fbd74894b1551f095e"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7e351589da0850c125f1600a4c4ba3c722efefe16b297de54300f08d734fbf"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5056b185ca113c88e18223183aa1a50e66507769c9640a6ff75859619d73957b"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e34b51b650b23ed3354b5a07aab37034d9f923db2a40519139af34f485f77d0"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5670bce7b200273eee1840ef307bfa07cda90b38ae56e9a6ebcc9f50da9c469b"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08986dce1339bc932923e7d1232ce9881499a0e02925f7402fb7c982515419ef"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93c0b12d3d3bc25af4ebbf38f9ee780a487e8bf6954c115b9f015822d3bb8e48"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:764e71f22ab3b305e7f4c21f1a97e1526a25ebdd22513e251cf376760213da13"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f056bf21105c2515c32372bbc057f43eb02aae2fda61052e2f7622c801f0b4e2"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:69ab78f848845569401469da20df3e081e6b5a11cb086de3eed1d48f5ed57c95"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:86fddba590aad9208e2fa8b43b4c098bb0ec74f15718bb6a704e3c63e2cef3e9"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:684d7a212682996d21ca12ef3c17353c021fe9de6049e19ac8481ec35574a70f"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a03e02f48cd1abbd9f3b7e3586d97c8f7a9721c436f51a5245b3b9483044480b"},
+ {file = "regex-2024.11.6-cp39-cp39-win32.whl", hash = "sha256:41758407fc32d5c3c5de163888068cfee69cb4c2be844e7ac517a52770f9af57"},
+ {file = "regex-2024.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b2837718570f95dd41675328e111345f9b7095d821bac435aac173ac80b19983"},
+ {file = "regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519"},
+]
+
+[[package]]
+name = "requests"
+version = "2.32.3"
+description = "Python HTTP for Humans."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+ {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
+]
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset-normalizer = ">=2,<4"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<3"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+
+[[package]]
+name = "safetensors"
+version = "0.4.5"
+description = ""
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "safetensors-0.4.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a63eaccd22243c67e4f2b1c3e258b257effc4acd78f3b9d397edc8cf8f1298a7"},
+ {file = "safetensors-0.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:23fc9b4ec7b602915cbb4ec1a7c1ad96d2743c322f20ab709e2c35d1b66dad27"},
+ {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6885016f34bef80ea1085b7e99b3c1f92cb1be78a49839203060f67b40aee761"},
+ {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:133620f443450429322f238fda74d512c4008621227fccf2f8cf4a76206fea7c"},
+ {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4fb3e0609ec12d2a77e882f07cced530b8262027f64b75d399f1504ffec0ba56"},
+ {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0f1dd769f064adc33831f5e97ad07babbd728427f98e3e1db6902e369122737"},
+ {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6d156bdb26732feada84f9388a9f135528c1ef5b05fae153da365ad4319c4c5"},
+ {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e347d77e2c77eb7624400ccd09bed69d35c0332f417ce8c048d404a096c593b"},
+ {file = "safetensors-0.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9f556eea3aec1d3d955403159fe2123ddd68e880f83954ee9b4a3f2e15e716b6"},
+ {file = "safetensors-0.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9483f42be3b6bc8ff77dd67302de8ae411c4db39f7224dec66b0eb95822e4163"},
+ {file = "safetensors-0.4.5-cp310-none-win32.whl", hash = "sha256:7389129c03fadd1ccc37fd1ebbc773f2b031483b04700923c3511d2a939252cc"},
+ {file = "safetensors-0.4.5-cp310-none-win_amd64.whl", hash = "sha256:e98ef5524f8b6620c8cdef97220c0b6a5c1cef69852fcd2f174bb96c2bb316b1"},
+ {file = "safetensors-0.4.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:21f848d7aebd5954f92538552d6d75f7c1b4500f51664078b5b49720d180e47c"},
+ {file = "safetensors-0.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb07000b19d41e35eecef9a454f31a8b4718a185293f0d0b1c4b61d6e4487971"},
+ {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09dedf7c2fda934ee68143202acff6e9e8eb0ddeeb4cfc24182bef999efa9f42"},
+ {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59b77e4b7a708988d84f26de3ebead61ef1659c73dcbc9946c18f3b1786d2688"},
+ {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d3bc83e14d67adc2e9387e511097f254bd1b43c3020440e708858c684cbac68"},
+ {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39371fc551c1072976073ab258c3119395294cf49cdc1f8476794627de3130df"},
+ {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6c19feda32b931cae0acd42748a670bdf56bee6476a046af20181ad3fee4090"},
+ {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a659467495de201e2f282063808a41170448c78bada1e62707b07a27b05e6943"},
+ {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bad5e4b2476949bcd638a89f71b6916fa9a5cae5c1ae7eede337aca2100435c0"},
+ {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a3a315a6d0054bc6889a17f5668a73f94f7fe55121ff59e0a199e3519c08565f"},
+ {file = "safetensors-0.4.5-cp311-none-win32.whl", hash = "sha256:a01e232e6d3d5cf8b1667bc3b657a77bdab73f0743c26c1d3c5dd7ce86bd3a92"},
+ {file = "safetensors-0.4.5-cp311-none-win_amd64.whl", hash = "sha256:cbd39cae1ad3e3ef6f63a6f07296b080c951f24cec60188378e43d3713000c04"},
+ {file = "safetensors-0.4.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:473300314e026bd1043cef391bb16a8689453363381561b8a3e443870937cc1e"},
+ {file = "safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:801183a0f76dc647f51a2d9141ad341f9665602a7899a693207a82fb102cc53e"},
+ {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1524b54246e422ad6fb6aea1ac71edeeb77666efa67230e1faf6999df9b2e27f"},
+ {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3139098e3e8b2ad7afbca96d30ad29157b50c90861084e69fcb80dec7430461"},
+ {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65573dc35be9059770808e276b017256fa30058802c29e1038eb1c00028502ea"},
+ {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd33da8e9407559f8779c82a0448e2133737f922d71f884da27184549416bfed"},
+ {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3685ce7ed036f916316b567152482b7e959dc754fcc4a8342333d222e05f407c"},
+ {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dde2bf390d25f67908278d6f5d59e46211ef98e44108727084d4637ee70ab4f1"},
+ {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7469d70d3de970b1698d47c11ebbf296a308702cbaae7fcb993944751cf985f4"},
+ {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a6ba28118636a130ccbb968bc33d4684c48678695dba2590169d5ab03a45646"},
+ {file = "safetensors-0.4.5-cp312-none-win32.whl", hash = "sha256:c859c7ed90b0047f58ee27751c8e56951452ed36a67afee1b0a87847d065eec6"},
+ {file = "safetensors-0.4.5-cp312-none-win_amd64.whl", hash = "sha256:b5a8810ad6a6f933fff6c276eae92c1da217b39b4d8b1bc1c0b8af2d270dc532"},
+ {file = "safetensors-0.4.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:25e5f8e2e92a74f05b4ca55686234c32aac19927903792b30ee6d7bd5653d54e"},
+ {file = "safetensors-0.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:81efb124b58af39fcd684254c645e35692fea81c51627259cdf6d67ff4458916"},
+ {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:585f1703a518b437f5103aa9cf70e9bd437cb78eea9c51024329e4fb8a3e3679"},
+ {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b99fbf72e3faf0b2f5f16e5e3458b93b7d0a83984fe8d5364c60aa169f2da89"},
+ {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b17b299ca9966ca983ecda1c0791a3f07f9ca6ab5ded8ef3d283fff45f6bcd5f"},
+ {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76ded72f69209c9780fdb23ea89e56d35c54ae6abcdec67ccb22af8e696e449a"},
+ {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2783956926303dcfeb1de91a4d1204cd4089ab441e622e7caee0642281109db3"},
+ {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d94581aab8c6b204def4d7320f07534d6ee34cd4855688004a4354e63b639a35"},
+ {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:67e1e7cb8678bb1b37ac48ec0df04faf689e2f4e9e81e566b5c63d9f23748523"},
+ {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142"},
+ {file = "safetensors-0.4.5-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:77d9b228da8374c7262046a36c1f656ba32a93df6cc51cd4453af932011e77f1"},
+ {file = "safetensors-0.4.5-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:500cac01d50b301ab7bb192353317035011c5ceeef0fca652f9f43c000bb7f8d"},
+ {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75331c0c746f03158ded32465b7d0b0e24c5a22121743662a2393439c43a45cf"},
+ {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670e95fe34e0d591d0529e5e59fd9d3d72bc77b1444fcaa14dccda4f36b5a38b"},
+ {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:098923e2574ff237c517d6e840acada8e5b311cb1fa226019105ed82e9c3b62f"},
+ {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13ca0902d2648775089fa6a0c8fc9e6390c5f8ee576517d33f9261656f851e3f"},
+ {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f0032bedc869c56f8d26259fe39cd21c5199cd57f2228d817a0e23e8370af25"},
+ {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4b15f51b4f8f2a512341d9ce3475cacc19c5fdfc5db1f0e19449e75f95c7dc8"},
+ {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f6594d130d0ad933d885c6a7b75c5183cb0e8450f799b80a39eae2b8508955eb"},
+ {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:60c828a27e852ded2c85fc0f87bf1ec20e464c5cd4d56ff0e0711855cc2e17f8"},
+ {file = "safetensors-0.4.5-cp37-none-win32.whl", hash = "sha256:6d3de65718b86c3eeaa8b73a9c3d123f9307a96bbd7be9698e21e76a56443af5"},
+ {file = "safetensors-0.4.5-cp37-none-win_amd64.whl", hash = "sha256:5a2d68a523a4cefd791156a4174189a4114cf0bf9c50ceb89f261600f3b2b81a"},
+ {file = "safetensors-0.4.5-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:e7a97058f96340850da0601a3309f3d29d6191b0702b2da201e54c6e3e44ccf0"},
+ {file = "safetensors-0.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:63bfd425e25f5c733f572e2246e08a1c38bd6f2e027d3f7c87e2e43f228d1345"},
+ {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3664ac565d0e809b0b929dae7ccd74e4d3273cd0c6d1220c6430035befb678e"},
+ {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:313514b0b9b73ff4ddfb4edd71860696dbe3c1c9dc4d5cc13dbd74da283d2cbf"},
+ {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31fa33ee326f750a2f2134a6174773c281d9a266ccd000bd4686d8021f1f3dac"},
+ {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:09566792588d77b68abe53754c9f1308fadd35c9f87be939e22c623eaacbed6b"},
+ {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309aaec9b66cbf07ad3a2e5cb8a03205663324fea024ba391594423d0f00d9fe"},
+ {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:53946c5813b8f9e26103c5efff4a931cc45d874f45229edd68557ffb35ffb9f8"},
+ {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:868f9df9e99ad1e7f38c52194063a982bc88fedc7d05096f4f8160403aaf4bd6"},
+ {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9cc9449bd0b0bc538bd5e268221f0c5590bc5c14c1934a6ae359d44410dc68c4"},
+ {file = "safetensors-0.4.5-cp38-none-win32.whl", hash = "sha256:83c4f13a9e687335c3928f615cd63a37e3f8ef072a3f2a0599fa09f863fb06a2"},
+ {file = "safetensors-0.4.5-cp38-none-win_amd64.whl", hash = "sha256:b98d40a2ffa560653f6274e15b27b3544e8e3713a44627ce268f419f35c49478"},
+ {file = "safetensors-0.4.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cf727bb1281d66699bef5683b04d98c894a2803442c490a8d45cd365abfbdeb2"},
+ {file = "safetensors-0.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:96f1d038c827cdc552d97e71f522e1049fef0542be575421f7684756a748e457"},
+ {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:139fbee92570ecea774e6344fee908907db79646d00b12c535f66bc78bd5ea2c"},
+ {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c36302c1c69eebb383775a89645a32b9d266878fab619819ce660309d6176c9b"},
+ {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d641f5b8149ea98deb5ffcf604d764aad1de38a8285f86771ce1abf8e74c4891"},
+ {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b4db6a61d968de73722b858038c616a1bebd4a86abe2688e46ca0cc2d17558f2"},
+ {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b75a616e02f21b6f1d5785b20cecbab5e2bd3f6358a90e8925b813d557666ec1"},
+ {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:788ee7d04cc0e0e7f944c52ff05f52a4415b312f5efd2ee66389fb7685ee030c"},
+ {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87bc42bd04fd9ca31396d3ca0433db0be1411b6b53ac5a32b7845a85d01ffc2e"},
+ {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4037676c86365a721a8c9510323a51861d703b399b78a6b4486a54a65a975fca"},
+ {file = "safetensors-0.4.5-cp39-none-win32.whl", hash = "sha256:1500418454529d0ed5c1564bda376c4ddff43f30fce9517d9bee7bcce5a8ef50"},
+ {file = "safetensors-0.4.5-cp39-none-win_amd64.whl", hash = "sha256:9d1a94b9d793ed8fe35ab6d5cea28d540a46559bafc6aae98f30ee0867000cab"},
+ {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fdadf66b5a22ceb645d5435a0be7a0292ce59648ca1d46b352f13cff3ea80410"},
+ {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d42ffd4c2259f31832cb17ff866c111684c87bd930892a1ba53fed28370c918c"},
+ {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd8a1f6d2063a92cd04145c7fd9e31a1c7d85fbec20113a14b487563fdbc0597"},
+ {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951d2fcf1817f4fb0ef0b48f6696688a4e852a95922a042b3f96aaa67eedc920"},
+ {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ac85d9a8c1af0e3132371d9f2d134695a06a96993c2e2f0bbe25debb9e3f67a"},
+ {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e3cec4a29eb7fe8da0b1c7988bc3828183080439dd559f720414450de076fcab"},
+ {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f"},
+ {file = "safetensors-0.4.5-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c7db3006a4915151ce1913652e907cdede299b974641a83fbc092102ac41b644"},
+ {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f68bf99ea970960a237f416ea394e266e0361895753df06e3e06e6ea7907d98b"},
+ {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8158938cf3324172df024da511839d373c40fbfaa83e9abf467174b2910d7b4c"},
+ {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:540ce6c4bf6b58cb0fd93fa5f143bc0ee341c93bb4f9287ccd92cf898cc1b0dd"},
+ {file = "safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bfeaa1a699c6b9ed514bd15e6a91e74738b71125a9292159e3d6b7f0a53d2cde"},
+ {file = "safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:01c8f00da537af711979e1b42a69a8ec9e1d7112f208e0e9b8a35d2c381085ef"},
+ {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a0dd565f83b30f2ca79b5d35748d0d99dd4b3454f80e03dfb41f0038e3bdf180"},
+ {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:023b6e5facda76989f4cba95a861b7e656b87e225f61811065d5c501f78cdb3f"},
+ {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9633b663393d5796f0b60249549371e392b75a0b955c07e9c6f8708a87fc841f"},
+ {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78dd8adfb48716233c45f676d6e48534d34b4bceb50162c13d1f0bdf6f78590a"},
+ {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8e8deb16c4321d61ae72533b8451ec4a9af8656d1c61ff81aa49f966406e4b68"},
+ {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:52452fa5999dc50c4decaf0c53aa28371f7f1e0fe5c2dd9129059fbe1e1599c7"},
+ {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d5f23198821e227cfc52d50fa989813513db381255c6d100927b012f0cfec63d"},
+ {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f4beb84b6073b1247a773141a6331117e35d07134b3bb0383003f39971d414bb"},
+ {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:68814d599d25ed2fdd045ed54d370d1d03cf35e02dce56de44c651f828fb9b7b"},
+ {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0b6453c54c57c1781292c46593f8a37254b8b99004c68d6c3ce229688931a22"},
+ {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:adaa9c6dead67e2dd90d634f89131e43162012479d86e25618e821a03d1eb1dc"},
+ {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73e7d408e9012cd17511b382b43547850969c7979efc2bc353f317abaf23c84c"},
+ {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:775409ce0fcc58b10773fdb4221ed1eb007de10fe7adbdf8f5e8a56096b6f0bc"},
+ {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:834001bed193e4440c4a3950a31059523ee5090605c907c66808664c932b549c"},
+ {file = "safetensors-0.4.5.tar.gz", hash = "sha256:d73de19682deabb02524b3d5d1f8b3aaba94c72f1bbfc7911b9b9d5d391c0310"},
+]
+
+[package.extras]
+all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"]
+dev = ["safetensors[all]"]
+jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"]
+mlx = ["mlx (>=0.0.9)"]
+numpy = ["numpy (>=1.21.6)"]
+paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"]
+pinned-tf = ["safetensors[numpy]", "tensorflow (==2.11.0)"]
+quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"]
+tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
+testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
+torch = ["safetensors[numpy]", "torch (>=1.10)"]
+
+[[package]]
+name = "tokenizers"
+version = "0.21.0"
+description = ""
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
+ {file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"},
+ {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"},
+ {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"},
+ {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"},
+ {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"},
+ {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"},
+ {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"},
+ {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"},
+ {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"},
+ {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"},
+ {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"},
+ {file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"},
+ {file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"},
+ {file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.16.4,<1.0"
+
+[package.extras]
+dev = ["tokenizers[testing]"]
+docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
+testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
+
+[[package]]
+name = "tomli"
+version = "2.2.1"
+description = "A lil' TOML parser"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
+ {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"},
+ {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"},
+ {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"},
+ {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"},
+ {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"},
+ {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"},
+ {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"},
+ {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"},
+ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
+]
+
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
+ {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
+discord = ["requests"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
+[[package]]
+name = "transformers"
+version = "4.47.0"
+description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
+optional = false
+python-versions = ">=3.9.0"
+files = [
+ {file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"},
+ {file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"},
+]
+
+[package.dependencies]
+filelock = "*"
+huggingface-hub = ">=0.24.0,<1.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+pyyaml = ">=5.1"
+regex = "!=2019.12.17"
+requests = "*"
+safetensors = ">=0.4.1"
+tokenizers = ">=0.21,<0.22"
+tqdm = ">=4.27"
+
+[package.extras]
+accelerate = ["accelerate (>=0.26.0)"]
+agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
+all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision"]
+audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+benchmark = ["optimum-benchmark (>=0.3.0)"]
+codecarbon = ["codecarbon (==1.2.0)"]
+deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
+deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
+dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
+flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+ftfy = ["ftfy"]
+integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
+ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
+modelcreation = ["cookiecutter (==1.7.3)"]
+natten = ["natten (>=0.14.6,<0.15.0)"]
+onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
+onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
+optuna = ["optuna"]
+quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
+ray = ["ray[tune] (>=2.7.0)"]
+retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
+ruff = ["ruff (==0.5.1)"]
+sagemaker = ["sagemaker (>=2.31.0)"]
+sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
+serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
+sigopt = ["sigopt"]
+sklearn = ["scikit-learn"]
+speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
+testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
+tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
+tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+tiktoken = ["blobfile", "tiktoken"]
+timm = ["timm (<=1.0.11)"]
+tokenizers = ["tokenizers (>=0.21,<0.22)"]
+torch = ["accelerate (>=0.26.0)", "torch"]
+torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
+torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
+torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch", "tqdm (>=4.27)"]
+video = ["av (==9.2.0)"]
+vision = ["Pillow (>=10.0.1,<=15.0)"]
+
+[[package]]
+name = "typing-extensions"
+version = "4.12.2"
+description = "Backported and Experimental Type Hints for Python 3.8+"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
+ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
+]
+
+[[package]]
+name = "urllib3"
+version = "2.2.3"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"},
+ {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"},
+]
+
+[package.extras]
+brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
+h2 = ["h2 (>=4,<5)"]
+socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
+zstd = ["zstandard (>=0.18.0)"]
+
+[metadata]
+lock-version = "2.0"
+python-versions = ">=3.10,<3.13"
+content-hash = "de2c60fd8f7c54521b2204b3af144b974059361d7892c70eeb325a7fe52bd489"
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..6e26980
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,69 @@
+[tool.ruff]
+line-length = 119
+target-version = "py38"
+
+
+[tool.ruff.lint]
+preview = true
+extend-select = [
+ "B009", # static getattr
+ "B010", # static setattr
+ "CPY", # Copyright
+ "E", # PEP8 errors
+ "F", # PEP8 formatting
+ "I", # Import sorting
+ "TID251", # Banned API
+ "UP", # Pyupgrade
+ "W", # PEP8 warnings
+]
+ignore = [
+ "E501", # Line length (handled by ruff-format)
+ "E741", # Ambiguous variable name
+ "W605", # Invalid escape sequence
+ "UP007", # X | Y type annotations
+]
+
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = [
+ "F401", # Ignore seemingly unused imports (they're meant for re-export)
+]
+"manim_animations/*" = ["ALL"]
+
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = ["accelerate"]
+
+
+[tool.ruff.format]
+exclude = [
+ "manim_animations/*"
+]
+
+
+[tool.ruff.lint.flake8-tidy-imports.banned-api]
+"os.getenv".msg = "Use os.environ instead"
+"os.putenv".msg = "Use os.environ instead"
+"os.unsetenv".msg = "Use os.environ instead"
+
+
+[tool.poetry]
+name = "agents"
+version = "0.1.0"
+description = "Agents : The simplest way to build agentic systems."
+authors = ["Aymeric Roucher"]
+license = "Apache 2.0"
+readme = "README.md"
+
+
+[tool.poetry.dependencies]
+python = ">=3.10,<3.13"
+transformers = ">=4.0.0"
+pytest = {version = ">=8.1.0", optional = true}
+requests = "^2.32.3"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..ebdb781
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,122 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from setuptools import find_packages, setup
+
+
+extras = {}
+extras["quality"] = [
+ "black ~= 23.1", # hf-doc-builder has a hidden dependency on `black`
+ "hf-doc-builder >= 0.3.0",
+ "ruff ~= 0.6.4",
+]
+extras["docs"] = []
+extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
+extras["test_dev"] = [
+ "datasets",
+ "diffusers",
+ "evaluate",
+ "torchdata>=0.8.0",
+ "torchpippy>=0.2.0",
+ "transformers",
+ "scipy",
+ "scikit-learn",
+ "tqdm",
+ "bitsandbytes",
+ "timm",
+]
+extras["testing"] = extras["test_prod"] + extras["test_dev"]
+extras["deepspeed"] = ["deepspeed"]
+extras["rich"] = ["rich"]
+
+extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"]
+extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
+
+extras["sagemaker"] = [
+ "sagemaker", # boto3 is a required package in sagemaker
+]
+
+setup(
+ name="accelerate",
+ version="1.2.0.dev0",
+ description="Accelerate",
+ long_description=open("README.md", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ keywords="deep learning",
+ license="Apache",
+ author="The HuggingFace team",
+ author_email="zach.mueller@huggingface.co",
+ url="https://github.com/huggingface/accelerate",
+ package_dir={"": "src"},
+ packages=find_packages("src"),
+ entry_points={
+ "console_scripts": [
+ "accelerate=accelerate.commands.accelerate_cli:main",
+ "accelerate-config=accelerate.commands.config:main",
+ "accelerate-estimate-memory=accelerate.commands.estimate:main",
+ "accelerate-launch=accelerate.commands.launch:main",
+ "accelerate-merge-weights=accelerate.commands.merge:main",
+ ]
+ },
+ python_requires=">=3.9.0",
+ install_requires=[
+ "numpy>=1.17,<3.0.0",
+ "packaging>=20.0",
+ "psutil",
+ "pyyaml",
+ "torch>=1.10.0",
+ "huggingface_hub>=0.21.0",
+ "safetensors>=0.4.3",
+ ],
+ extras_require=extras,
+ classifiers=[
+ "Development Status :: 5 - Production/Stable",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+)
+
+# Release checklist
+# 1. Checkout the release branch (for a patch the current release branch, for a new minor version, create one):
+# git checkout -b vXX.xx-release
+# The -b is only necessary for creation (so remove it when doing a patch)
+# 2. Change the version in __init__.py and setup.py to the proper value.
+# 3. Commit these changes with the message: "Release: v"
+# 4. Add a tag in git to mark the release:
+# git tag v -m 'Adds tag v for pypi'
+# Push the tag and release commit to git: git push --tags origin vXX.xx-release
+# 5. Run the following commands in the top-level directory:
+# rm -rf dist
+# rm -rf build
+# python setup.py bdist_wheel
+# python setup.py sdist
+# 6. Upload the package to the pypi test server first:
+# twine upload dist/* -r testpypi
+# 7. Check that you can install it in a virtualenv by running:
+# pip install accelerate
+# pip uninstall accelerate
+# pip install -i https://testpypi.python.org/pypi accelerate
+# accelerate env
+# accelerate test
+# 8. Upload the final version to actual pypi:
+# twine upload dist/* -r pypi
+# 9. Add release notes to the tag in github once everything is looking hunky-dory.
+# 10. Go back to the main branch and update the version in __init__.py, setup.py to the new version ".dev" and push to
+# main.
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_agent_types.py b/tests/test_agent_types.py
new file mode 100644
index 0000000..168ab1b
--- /dev/null
+++ b/tests/test_agent_types.py
@@ -0,0 +1,121 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+import uuid
+from pathlib import Path
+
+from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
+from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
+from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
+
+
+if is_torch_available():
+ import torch
+
+if is_soundfile_availble():
+ import soundfile as sf
+
+if is_vision_available():
+ from PIL import Image
+
+
+def get_new_path(suffix="") -> str:
+ directory = tempfile.mkdtemp()
+ return os.path.join(directory, str(uuid.uuid4()) + suffix)
+
+
+@require_soundfile
+@require_torch
+class AgentAudioTests(unittest.TestCase):
+ def test_from_tensor(self):
+ tensor = torch.rand(12, dtype=torch.float64) - 0.5
+ agent_type = AgentAudio(tensor)
+ path = str(agent_type.to_string())
+
+ # Ensure that the tensor and the agent_type's tensor are the same
+ self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
+
+ del agent_type
+
+ # Ensure the path remains even after the object deletion
+ self.assertTrue(os.path.exists(path))
+
+ # Ensure that the file contains the same value as the original tensor
+ new_tensor, _ = sf.read(path)
+ self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
+
+ def test_from_string(self):
+ tensor = torch.rand(12, dtype=torch.float64) - 0.5
+ path = get_new_path(suffix=".wav")
+ sf.write(path, tensor, 16000)
+
+ agent_type = AgentAudio(path)
+
+ self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
+ self.assertEqual(agent_type.to_string(), path)
+
+
+@require_vision
+@require_torch
+class AgentImageTests(unittest.TestCase):
+ def test_from_tensor(self):
+ tensor = torch.randint(0, 256, (64, 64, 3))
+ agent_type = AgentImage(tensor)
+ path = str(agent_type.to_string())
+
+ # Ensure that the tensor and the agent_type's tensor are the same
+ self.assertTrue(torch.allclose(tensor, agent_type._tensor, atol=1e-4))
+
+ self.assertIsInstance(agent_type.to_raw(), Image.Image)
+
+ # Ensure the path remains even after the object deletion
+ del agent_type
+ self.assertTrue(os.path.exists(path))
+
+ def test_from_string(self):
+ path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
+ image = Image.open(path)
+ agent_type = AgentImage(path)
+
+ self.assertTrue(path.samefile(agent_type.to_string()))
+ self.assertTrue(image == agent_type.to_raw())
+
+ # Ensure the path remains even after the object deletion
+ del agent_type
+ self.assertTrue(os.path.exists(path))
+
+ def test_from_image(self):
+ path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
+ image = Image.open(path)
+ agent_type = AgentImage(image)
+
+ self.assertFalse(path.samefile(agent_type.to_string()))
+ self.assertTrue(image == agent_type.to_raw())
+
+ # Ensure the path remains even after the object deletion
+ del agent_type
+ self.assertTrue(os.path.exists(path))
+
+
+class AgentTextTests(unittest.TestCase):
+ def test_from_string(self):
+ string = "Hey!"
+ agent_type = AgentText(string)
+
+ self.assertEqual(string, agent_type.to_string())
+ self.assertEqual(string, agent_type.to_raw())
+ self.assertEqual(string, agent_type)
diff --git a/tests/test_agents.py b/tests/test_agents.py
new file mode 100644
index 0000000..4f24abb
--- /dev/null
+++ b/tests/test_agents.py
@@ -0,0 +1,258 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+import uuid
+
+import pytest
+
+from transformers.agents.agent_types import AgentText
+from transformers.agents.agents import (
+ AgentMaxIterationsError,
+ CodeAgent,
+ ManagedAgent,
+ ReactCodeAgent,
+ ReactJsonAgent,
+ Toolbox,
+)
+from transformers.agents.default_tools import PythonInterpreterTool
+from transformers.testing_utils import require_torch
+
+
+def get_new_path(suffix="") -> str:
+ directory = tempfile.mkdtemp()
+ return os.path.join(directory, str(uuid.uuid4()) + suffix)
+
+
+def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str:
+ prompt = str(messages)
+
+ if "special_marker" not in prompt:
+ return """
+Thought: I should multiply 2 by 3.6452. special_marker
+Action:
+{
+ "action": "python_interpreter",
+ "action_input": {"code": "2*3.6452"}
+}
+"""
+ else: # We're at step 2
+ return """
+Thought: I can now answer the initial question
+Action:
+{
+ "action": "final_answer",
+ "action_input": {"answer": "7.2904"}
+}
+"""
+
+
+def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str:
+ prompt = str(messages)
+ if "special_marker" not in prompt:
+ return """
+Thought: I should multiply 2 by 3.6452. special_marker
+Code:
+```py
+result = 2**3.6452
+```
+"""
+ else: # We're at step 2
+ return """
+Thought: I can now answer the initial question
+Code:
+```py
+final_answer(7.2904)
+```
+"""
+
+
+def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
+ prompt = str(messages)
+ if "special_marker" not in prompt:
+ return """
+Thought: I should multiply 2 by 3.6452. special_marker
+Code:
+```py
+print = 2
+```
+"""
+ else: # We're at step 2
+ return """
+Thought: I can now answer the initial question
+Code:
+```py
+final_answer("got an error")
+```
+"""
+
+
+def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
+ prompt = str(messages)
+ if "special_marker" not in prompt:
+ return """
+Thought: Let's define the function. special_marker
+Code:
+```py
+import numpy as np
+
+def moving_average(x, w):
+ return np.convolve(x, np.ones(w), 'valid') / w
+```
+"""
+ else: # We're at step 2
+ return """
+Thought: I can now answer the initial question
+Code:
+```py
+x, w = [0, 1, 2, 3, 4, 5], 2
+res = moving_average(x, w)
+final_answer(res)
+```
+"""
+
+
+def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
+ return """
+Thought: I should multiply 2 by 3.6452. special_marker
+Code:
+```py
+result = python_interpreter(code="2*3.6452")
+final_answer(result)
+```
+"""
+
+
+def fake_code_llm_no_return(messages, stop_sequences=None, grammar=None) -> str:
+ return """
+Thought: I should multiply 2 by 3.6452. special_marker
+Code:
+```py
+result = python_interpreter(code="2*3.6452")
+print(result)
+```
+"""
+
+
+class AgentTests(unittest.TestCase):
+ def test_fake_code_agent(self):
+ agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
+ output = agent.run("What is 2 multiplied by 3.6452?")
+ assert isinstance(output, str)
+ assert output == "7.2904"
+
+ def test_fake_react_json_agent(self):
+ agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
+ output = agent.run("What is 2 multiplied by 3.6452?")
+ assert isinstance(output, str)
+ assert output == "7.2904"
+ assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
+ assert agent.logs[1]["observation"] == "7.2904"
+ assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
+ assert (
+ agent.logs[2]["llm_output"]
+ == """
+Thought: I can now answer the initial question
+Action:
+{
+ "action": "final_answer",
+ "action_input": {"answer": "7.2904"}
+}
+"""
+ )
+
+ def test_fake_react_code_agent(self):
+ agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
+ output = agent.run("What is 2 multiplied by 3.6452?")
+ assert isinstance(output, float)
+ assert output == 7.2904
+ assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
+ assert agent.logs[2]["tool_call"] == {
+ "tool_arguments": "final_answer(7.2904)",
+ "tool_name": "code interpreter",
+ }
+
+ def test_react_code_agent_code_errors_show_offending_lines(self):
+ agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error)
+ output = agent.run("What is 2 multiplied by 3.6452?")
+ assert isinstance(output, AgentText)
+ assert output == "got an error"
+ assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
+
+ def test_setup_agent_with_empty_toolbox(self):
+ ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
+
+ def test_react_fails_max_iterations(self):
+ agent = ReactCodeAgent(
+ tools=[PythonInterpreterTool()],
+ llm_engine=fake_code_llm_no_return, # use this callable because it never ends
+ max_iterations=5,
+ )
+ agent.run("What is 2 multiplied by 3.6452?")
+ assert len(agent.logs) == 7
+ assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError
+
+ @require_torch
+ def test_init_agent_with_different_toolsets(self):
+ toolset_1 = []
+ agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
+ assert (
+ len(agent.toolbox.tools) == 1
+ ) # when no tools are provided, only the final_answer tool is added by default
+
+ toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
+ agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
+ assert (
+ len(agent.toolbox.tools) == 2
+ ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
+
+ toolset_3 = Toolbox(toolset_2)
+ agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
+ assert (
+ len(agent.toolbox.tools) == 2
+ ) # same as previous one, where toolset_3 is an instantiation of previous one
+
+ # check that add_base_tools will not interfere with existing tools
+ with pytest.raises(KeyError) as e:
+ agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
+ assert "already exists in the toolbox" in str(e)
+
+ # check that python_interpreter base tool does not get added to code agents
+ agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
+ assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter)
+
+ def test_function_persistence_across_steps(self):
+ agent = ReactCodeAgent(
+ tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
+ )
+ res = agent.run("ok")
+ assert res[0] == 0.5
+
+ def test_init_managed_agent(self):
+ agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
+ managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
+ assert managed_agent.name == "managed_agent"
+ assert managed_agent.description == "Empty"
+
+ def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
+ agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
+ managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
+ manager_agent = ReactCodeAgent(
+ tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent]
+ )
+ assert "You can also give requests to team members." not in agent.system_prompt
+ assert "<>" not in agent.system_prompt
+ assert "You can also give requests to team members." in manager_agent.system_prompt
diff --git a/tests/test_document_question_answering.py b/tests/test_document_question_answering.py
new file mode 100644
index 0000000..d135551
--- /dev/null
+++ b/tests/test_document_question_answering.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from datasets import load_dataset
+
+from transformers import load_tool
+
+from .test_tools_common import ToolTesterMixin
+
+
+class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = load_tool("document_question_answering")
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+ document = dataset[0]["image"]
+
+ result = self.tool(document, "When is the coffee break?")
+ self.assertEqual(result, "11-14 to 11:39 a.m.")
+
+ def test_exact_match_kwarg(self):
+ dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+ document = dataset[0]["image"]
+
+ self.tool(document=document, question="When is the coffee break?")
diff --git a/tests/test_examples.py b/tests/test_examples.py
new file mode 100644
index 0000000..a16dce5
--- /dev/null
+++ b/tests/test_examples.py
@@ -0,0 +1,290 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import os
+import re
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+from unittest import mock, skip
+
+import torch
+
+from accelerate.test_utils.examples import compare_against_test
+from accelerate.test_utils.testing import (
+ TempDirTestCase,
+ get_launch_command,
+ require_huggingface_suite,
+ require_multi_device,
+ require_multi_gpu,
+ require_non_xpu,
+ require_pippy,
+ require_schedulefree,
+ require_trackers,
+ run_command,
+ slow,
+)
+from accelerate.utils import write_basic_config
+
+
+# DataLoaders built from `test_samples/MRPC` for quick testing
+# Should mock `{script_name}.get_dataloaders` via:
+# @mock.patch("{script_name}.get_dataloaders", mocked_dataloaders)
+
+EXCLUDE_EXAMPLES = [
+ "cross_validation.py",
+ "checkpointing.py",
+ "gradient_accumulation.py",
+ "local_sgd.py",
+ "multi_process_metrics.py",
+ "memory.py",
+ "schedule_free.py",
+ "tracking.py",
+ "automatic_gradient_accumulation.py",
+ "fsdp_with_peak_mem_tracking.py",
+ "deepspeed_with_config_support.py",
+ "megatron_lm_gpt_pretraining.py",
+ "early_stopping.py",
+ "ddp_comm_hook.py",
+ "profiler.py",
+]
+
+
+class ExampleDifferenceTests(unittest.TestCase):
+ """
+ This TestCase checks that all of the `complete_*` scripts contain all of the
+ information found in the `by_feature` scripts, line for line. If one fails,
+ then a complete example does not contain all of the features in the features
+ scripts, and should be updated.
+
+ Each example script should be a single test (such as `test_nlp_example`),
+ and should run `one_complete_example` twice: once with `parser_only=True`,
+ and the other with `parser_only=False`. This is so that when the test
+ failures are returned to the user, they understand if the discrepancy lies in
+ the `main` function, or the `training_loop` function. Otherwise it will be
+ unclear.
+
+ Also, if there are any expected differences between the base script used and
+ `complete_nlp_example.py` (the canonical base script), these should be included in
+ `special_strings`. These would be differences in how something is logged, print statements,
+ etc (such as calls to `Accelerate.log()`)
+ """
+
+ by_feature_path = Path("examples", "by_feature").resolve()
+ examples_path = Path("examples").resolve()
+
+ def one_complete_example(
+ self, complete_file_name: str, parser_only: bool, secondary_filename: str = None, special_strings: list = None
+ ):
+ """
+ Tests a single `complete` example against all of the implemented `by_feature` scripts
+
+ Args:
+ complete_file_name (`str`):
+ The filename of a complete example
+ parser_only (`bool`):
+ Whether to look at the main training function, or the argument parser
+ secondary_filename (`str`, *optional*):
+ A potential secondary base file to strip all script information not relevant for checking,
+ such as "cv_example.py" when testing "complete_cv_example.py"
+ special_strings (`list`, *optional*):
+ A list of strings to potentially remove before checking no differences are left. These should be
+ diffs that are file specific, such as different logging variations between files.
+ """
+ self.maxDiff = None
+ for item in os.listdir(self.by_feature_path):
+ if item not in EXCLUDE_EXAMPLES:
+ item_path = self.by_feature_path / item
+ if item_path.is_file() and item_path.suffix == ".py":
+ with self.subTest(
+ tested_script=complete_file_name,
+ feature_script=item,
+ tested_section="main()" if parser_only else "training_function()",
+ ):
+ diff = compare_against_test(
+ self.examples_path / complete_file_name, item_path, parser_only, secondary_filename
+ )
+ diff = "\n".join(diff)
+ if special_strings is not None:
+ for string in special_strings:
+ diff = diff.replace(string, "")
+ assert diff == ""
+
+ def test_nlp_examples(self):
+ self.one_complete_example("complete_nlp_example.py", True)
+ self.one_complete_example("complete_nlp_example.py", False)
+
+ def test_cv_examples(self):
+ cv_path = (self.examples_path / "cv_example.py").resolve()
+ special_strings = [
+ " " * 16 + "{\n\n",
+ " " * 20 + '"accuracy": eval_metric["accuracy"],\n\n',
+ " " * 20 + '"f1": eval_metric["f1"],\n\n',
+ " " * 20 + '"train_loss": total_loss.item() / len(train_dataloader),\n\n',
+ " " * 20 + '"epoch": epoch,\n\n',
+ " " * 16 + "},\n\n",
+ " " * 16 + "step=epoch,\n",
+ " " * 12,
+ " " * 8 + "for step, batch in enumerate(active_dataloader):\n",
+ ]
+ self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings)
+ self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings)
+
+
+@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})
+@require_huggingface_suite
+class FeatureExamplesTests(TempDirTestCase):
+ clear_on_setup = False
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._tmpdir = tempfile.mkdtemp()
+ cls.config_file = Path(cls._tmpdir) / "default_config.yml"
+
+ write_basic_config(save_location=cls.config_file)
+ cls.launch_args = get_launch_command(config_file=cls.config_file)
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ shutil.rmtree(cls._tmpdir)
+
+ def test_checkpointing_by_epoch(self):
+ testargs = f"""
+ examples/by_feature/checkpointing.py
+ --checkpointing_steps epoch
+ --output_dir {self.tmpdir}
+ """.split()
+ run_command(self.launch_args + testargs)
+ assert (self.tmpdir / "epoch_0").exists()
+
+ def test_checkpointing_by_steps(self):
+ testargs = f"""
+ examples/by_feature/checkpointing.py
+ --checkpointing_steps 1
+ --output_dir {self.tmpdir}
+ """.split()
+ _ = run_command(self.launch_args + testargs)
+ assert (self.tmpdir / "step_2").exists()
+
+ def test_load_states_by_epoch(self):
+ testargs = f"""
+ examples/by_feature/checkpointing.py
+ --resume_from_checkpoint {self.tmpdir / "epoch_0"}
+ """.split()
+ output = run_command(self.launch_args + testargs, return_stdout=True)
+ assert "epoch 0:" not in output
+ assert "epoch 1:" in output
+
+ def test_load_states_by_steps(self):
+ testargs = f"""
+ examples/by_feature/checkpointing.py
+ --resume_from_checkpoint {self.tmpdir / "step_2"}
+ """.split()
+ output = run_command(self.launch_args + testargs, return_stdout=True)
+ if torch.cuda.is_available():
+ num_processes = torch.cuda.device_count()
+ else:
+ num_processes = 1
+ if num_processes > 1:
+ assert "epoch 0:" not in output
+ assert "epoch 1:" in output
+ else:
+ assert "epoch 0:" in output
+ assert "epoch 1:" in output
+
+ @slow
+ def test_cross_validation(self):
+ testargs = """
+ examples/by_feature/cross_validation.py
+ --num_folds 2
+ """.split()
+ with mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "0"}):
+ output = run_command(self.launch_args + testargs, return_stdout=True)
+ results = re.findall("({.+})", output)
+ results = [r for r in results if "accuracy" in r][-1]
+ results = ast.literal_eval(results)
+ assert results["accuracy"] >= 0.75
+
+ def test_multi_process_metrics(self):
+ testargs = ["examples/by_feature/multi_process_metrics.py"]
+ run_command(self.launch_args + testargs)
+
+ @require_schedulefree
+ def test_schedulefree(self):
+ testargs = ["examples/by_feature/schedule_free.py"]
+ run_command(self.launch_args + testargs)
+
+ @require_trackers
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
+ def test_tracking(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ testargs = f"""
+ examples/by_feature/tracking.py
+ --with_tracking
+ --project_dir {tmpdir}
+ """.split()
+ run_command(self.launch_args + testargs)
+ assert os.path.exists(os.path.join(tmpdir, "tracking"))
+
+ def test_gradient_accumulation(self):
+ testargs = ["examples/by_feature/gradient_accumulation.py"]
+ run_command(self.launch_args + testargs)
+
+ def test_local_sgd(self):
+ testargs = ["examples/by_feature/local_sgd.py"]
+ run_command(self.launch_args + testargs)
+
+ def test_early_stopping(self):
+ testargs = ["examples/by_feature/early_stopping.py"]
+ run_command(self.launch_args + testargs)
+
+ def test_profiler(self):
+ testargs = ["examples/by_feature/profiler.py"]
+ run_command(self.launch_args + testargs)
+
+ @require_multi_device
+ def test_ddp_comm_hook(self):
+ testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
+ run_command(self.launch_args + testargs)
+
+ @skip(
+ reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
+ )
+ @require_multi_device
+ def test_distributed_inference_examples_stable_diffusion(self):
+ testargs = ["examples/inference/distributed/stable_diffusion.py"]
+ run_command(self.launch_args + testargs)
+
+ @require_multi_device
+ def test_distributed_inference_examples_phi2(self):
+ testargs = ["examples/inference/distributed/phi2.py"]
+ run_command(self.launch_args + testargs)
+
+ @require_non_xpu
+ @require_pippy
+ @require_multi_gpu
+ def test_pippy_examples_bert(self):
+ testargs = ["examples/inference/pippy/bert.py"]
+ run_command(self.launch_args + testargs)
+
+ @require_non_xpu
+ @require_pippy
+ @require_multi_gpu
+ def test_pippy_examples_gpt2(self):
+ testargs = ["examples/inference/pippy/gpt2.py"]
+ run_command(self.launch_args + testargs)
diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py
new file mode 100644
index 0000000..91bdd65
--- /dev/null
+++ b/tests/test_final_answer.py
@@ -0,0 +1,71 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from pathlib import Path
+
+import numpy as np
+from PIL import Image
+
+from transformers import is_torch_available
+from transformers.agents.agent_types import AGENT_TYPE_MAPPING
+from transformers.agents.default_tools import FinalAnswerTool
+from transformers.testing_utils import get_tests_dir, require_torch
+
+from .test_tools_common import ToolTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+
+class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.inputs = {"answer": "Final answer"}
+ self.tool = FinalAnswerTool()
+
+ def test_exact_match_arg(self):
+ result = self.tool("Final answer")
+ self.assertEqual(result, "Final answer")
+
+ def test_exact_match_kwarg(self):
+ result = self.tool(answer=self.inputs["answer"])
+ self.assertEqual(result, "Final answer")
+
+ def create_inputs(self):
+ inputs_text = {"answer": "Text input"}
+ inputs_image = {
+ "answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
+ (512, 512)
+ )
+ }
+ inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
+ return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
+
+ @require_torch
+ def test_agent_type_output(self):
+ inputs = self.create_inputs()
+ for input_type, input in inputs.items():
+ output = self.tool(**input)
+ agent_type = AGENT_TYPE_MAPPING[input_type]
+ self.assertTrue(isinstance(output, agent_type))
+
+ @require_torch
+ def test_agent_types_inputs(self):
+ inputs = self.create_inputs()
+ for input_type, input in inputs.items():
+ output = self.tool(**input)
+ agent_type = AGENT_TYPE_MAPPING[input_type]
+ self.assertTrue(isinstance(output, agent_type))
diff --git a/tests/test_image_question_answering.py b/tests/test_image_question_answering.py
new file mode 100644
index 0000000..405933e
--- /dev/null
+++ b/tests/test_image_question_answering.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from pathlib import Path
+
+from transformers import is_vision_available, load_tool
+from transformers.testing_utils import get_tests_dir
+
+from .test_tools_common import ToolTesterMixin
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = load_tool("image_question_answering")
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
+ result = self.tool(image, "How many cats are sleeping on the couch?")
+ self.assertEqual(result, "2")
+
+ def test_exact_match_kwarg(self):
+ image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
+ result = self.tool(image=image, question="How many cats are sleeping on the couch?")
+ self.assertEqual(result, "2")
diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py
new file mode 100644
index 0000000..c350742
--- /dev/null
+++ b/tests/test_monitoring.py
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.agents.agent_types import AgentImage
+from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
+from transformers.agents.monitoring import stream_to_gradio
+
+
+class MonitoringTester(unittest.TestCase):
+ def test_code_agent_metrics(self):
+ class FakeLLMEngine:
+ def __init__(self):
+ self.last_input_token_count = 10
+ self.last_output_token_count = 20
+
+ def __call__(self, prompt, **kwargs):
+ return """
+Code:
+```py
+final_answer('This is the final answer.')
+```"""
+
+ agent = ReactCodeAgent(
+ tools=[],
+ llm_engine=FakeLLMEngine(),
+ max_iterations=1,
+ )
+
+ agent.run("Fake task")
+
+ self.assertEqual(agent.monitor.total_input_token_count, 10)
+ self.assertEqual(agent.monitor.total_output_token_count, 20)
+
+ def test_json_agent_metrics(self):
+ class FakeLLMEngine:
+ def __init__(self):
+ self.last_input_token_count = 10
+ self.last_output_token_count = 20
+
+ def __call__(self, prompt, **kwargs):
+ return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
+
+ agent = ReactJsonAgent(
+ tools=[],
+ llm_engine=FakeLLMEngine(),
+ max_iterations=1,
+ )
+
+ agent.run("Fake task")
+
+ self.assertEqual(agent.monitor.total_input_token_count, 10)
+ self.assertEqual(agent.monitor.total_output_token_count, 20)
+
+ def test_code_agent_metrics_max_iterations(self):
+ class FakeLLMEngine:
+ def __init__(self):
+ self.last_input_token_count = 10
+ self.last_output_token_count = 20
+
+ def __call__(self, prompt, **kwargs):
+ return "Malformed answer"
+
+ agent = ReactCodeAgent(
+ tools=[],
+ llm_engine=FakeLLMEngine(),
+ max_iterations=1,
+ )
+
+ agent.run("Fake task")
+
+ self.assertEqual(agent.monitor.total_input_token_count, 20)
+ self.assertEqual(agent.monitor.total_output_token_count, 40)
+
+ def test_code_agent_metrics_generation_error(self):
+ class FakeLLMEngine:
+ def __init__(self):
+ self.last_input_token_count = 10
+ self.last_output_token_count = 20
+
+ def __call__(self, prompt, **kwargs):
+ raise AgentError
+
+ agent = ReactCodeAgent(
+ tools=[],
+ llm_engine=FakeLLMEngine(),
+ max_iterations=1,
+ )
+
+ agent.run("Fake task")
+
+ self.assertEqual(agent.monitor.total_input_token_count, 20)
+ self.assertEqual(agent.monitor.total_output_token_count, 40)
+
+ def test_streaming_agent_text_output(self):
+ def dummy_llm_engine(prompt, **kwargs):
+ return """
+Code:
+```py
+final_answer('This is the final answer.')
+```"""
+
+ agent = ReactCodeAgent(
+ tools=[],
+ llm_engine=dummy_llm_engine,
+ max_iterations=1,
+ )
+
+ # Use stream_to_gradio to capture the output
+ outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
+
+ self.assertEqual(len(outputs), 3)
+ final_message = outputs[-1]
+ self.assertEqual(final_message.role, "assistant")
+ self.assertIn("This is the final answer.", final_message.content)
+
+ def test_streaming_agent_image_output(self):
+ def dummy_llm_engine(prompt, **kwargs):
+ return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
+
+ agent = ReactJsonAgent(
+ tools=[],
+ llm_engine=dummy_llm_engine,
+ max_iterations=1,
+ )
+
+ # Use stream_to_gradio to capture the output
+ outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))
+
+ self.assertEqual(len(outputs), 2)
+ final_message = outputs[-1]
+ self.assertEqual(final_message.role, "assistant")
+ self.assertIsInstance(final_message.content, dict)
+ self.assertEqual(final_message.content["path"], "path.png")
+ self.assertEqual(final_message.content["mime_type"], "image/png")
+
+ def test_streaming_with_agent_error(self):
+ def dummy_llm_engine(prompt, **kwargs):
+ raise AgentError("Simulated agent error")
+
+ agent = ReactCodeAgent(
+ tools=[],
+ llm_engine=dummy_llm_engine,
+ max_iterations=1,
+ )
+
+ # Use stream_to_gradio to capture the output
+ outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
+
+ self.assertEqual(len(outputs), 3)
+ final_message = outputs[-1]
+ self.assertEqual(final_message.role, "assistant")
+ self.assertIn("Simulated agent error", final_message.content)
diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py
new file mode 100644
index 0000000..15e5ad7
--- /dev/null
+++ b/tests/test_python_interpreter.py
@@ -0,0 +1,837 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import pytest
+
+from transformers import load_tool
+from transformers.agents.agent_types import AGENT_TYPE_MAPPING
+from transformers.agents.default_tools import BASE_PYTHON_TOOLS
+from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
+
+from .test_tools_common import ToolTesterMixin
+
+
+# Fake function we will use as tool
+def add_two(x):
+ return x + 2
+
+
+class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ result = self.tool("(2 / 2) * 4")
+ self.assertEqual(result, "4.0")
+
+ def test_exact_match_kwarg(self):
+ result = self.tool(code="(2 / 2) * 4")
+ self.assertEqual(result, "4.0")
+
+ def test_agent_type_output(self):
+ inputs = ["2 * 2"]
+ output = self.tool(*inputs)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))
+
+ def test_agent_types_inputs(self):
+ inputs = ["2 * 2"]
+ _inputs = []
+
+ for _input, expected_input in zip(inputs, self.tool.inputs.values()):
+ input_type = expected_input["type"]
+ if isinstance(input_type, list):
+ _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
+ else:
+ _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+
+ # Should not raise an error
+ output = self.tool(*inputs)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))
+
+
+class PythonInterpreterTester(unittest.TestCase):
+ def test_evaluate_assign(self):
+ code = "x = 3"
+ state = {}
+ result = evaluate_python_code(code, {}, state=state)
+ assert result == 3
+ self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
+
+ code = "x = y"
+ state = {"y": 5}
+ result = evaluate_python_code(code, {}, state=state)
+ # evaluate returns the value of the last assignment.
+ assert result == 5
+ self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
+
+ code = "a=1;b=None"
+ result = evaluate_python_code(code, {}, state={})
+ # evaluate returns the value of the last assignment.
+ assert result is None
+
+ def test_assignment_cannot_overwrite_tool(self):
+ code = "print = '3'"
+ with pytest.raises(InterpreterError) as e:
+ evaluate_python_code(code, {"print": print}, state={})
+ assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
+
+ def test_evaluate_call(self):
+ code = "y = add_two(x)"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ assert result == 5
+ self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
+
+ # Should not work without the tool
+ with pytest.raises(InterpreterError) as e:
+ evaluate_python_code(code, {}, state=state)
+ assert "tried to execute add_two" in str(e.value)
+
+ def test_evaluate_constant(self):
+ code = "x = 3"
+ state = {}
+ result = evaluate_python_code(code, {}, state=state)
+ assert result == 3
+ self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
+
+ def test_evaluate_dict(self):
+ code = "test_dict = {'x': x, 'y': add_two(x)}"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ self.assertDictEqual(result, {"x": 3, "y": 5})
+ self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
+
+ def test_evaluate_expression(self):
+ code = "x = 3\ny = 5"
+ state = {}
+ result = evaluate_python_code(code, {}, state=state)
+ # evaluate returns the value of the last assignment.
+ assert result == 5
+ self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
+
+ def test_evaluate_f_string(self):
+ code = "text = f'This is x: {x}.'"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {}, state=state)
+ # evaluate returns the value of the last assignment.
+ assert result == "This is x: 3."
+ self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
+
+ def test_evaluate_if(self):
+ code = "if x <= 3:\n y = 2\nelse:\n y = 5"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {}, state=state)
+ # evaluate returns the value of the last assignment.
+ assert result == 2
+ self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
+
+ state = {"x": 8}
+ result = evaluate_python_code(code, {}, state=state)
+ # evaluate returns the value of the last assignment.
+ assert result == 5
+ self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
+
+ def test_evaluate_list(self):
+ code = "test_list = [x, add_two(x)]"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ self.assertListEqual(result, [3, 5])
+ self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
+
+ def test_evaluate_name(self):
+ code = "y = x"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {}, state=state)
+ assert result == 3
+ self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
+
+ def test_evaluate_subscript(self):
+ code = "test_list = [x, add_two(x)]\ntest_list[1]"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ assert result == 5
+ self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
+
+ code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
+ state = {"x": 3}
+ result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ assert result == 5
+ self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
+
+ code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
+ state = {}
+ evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
+ assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
+
+ def test_subscript_string_with_string_index_raises_appropriate_error(self):
+ code = """
+search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
+for result in search_results:
+ if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
+ current_weather_url = result['href']
+ print(current_weather_url)
+ break"""
+ with pytest.raises(InterpreterError) as e:
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert "You're trying to subscript a string with a string index" in e
+
+ def test_evaluate_for(self):
+ code = "x = 0\nfor i in range(3):\n x = i"
+ state = {}
+ result = evaluate_python_code(code, {"range": range}, state=state)
+ assert result == 2
+ self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
+
+ def test_evaluate_binop(self):
+ code = "y + x"
+ state = {"x": 3, "y": 6}
+ result = evaluate_python_code(code, {}, state=state)
+ assert result == 9
+ self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
+
+ def test_recursive_function(self):
+ code = """
+def recur_fibo(n):
+ if n <= 1:
+ return n
+ else:
+ return(recur_fibo(n-1) + recur_fibo(n-2))
+recur_fibo(6)"""
+ result = evaluate_python_code(code, {}, state={})
+ assert result == 8
+
+ def test_evaluate_string_methods(self):
+ code = "'hello'.replace('h', 'o').split('e')"
+ result = evaluate_python_code(code, {}, state={})
+ assert result == ["o", "llo"]
+
+ def test_evaluate_slicing(self):
+ code = "'hello'[1:3][::-1]"
+ result = evaluate_python_code(code, {}, state={})
+ assert result == "le"
+
+ def test_access_attributes(self):
+ code = "integer = 1\nobj_class = integer.__class__\nobj_class"
+ result = evaluate_python_code(code, {}, state={})
+ assert result is int
+
+ def test_list_comprehension(self):
+ code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
+ result = evaluate_python_code(code, {}, state={})
+ assert result == "t-h-e-s-e-a-g-u-l-l"
+
+ def test_string_indexing(self):
+ code = """text_block = [
+ "THESE",
+ "AGULL"
+]
+sentence = ""
+for block in text_block:
+ for col in range(len(text_block[0])):
+ sentence += block[col]
+ """
+ result = evaluate_python_code(code, {"len": len, "range": range}, state={})
+ assert result == "THESEAGULL"
+
+ def test_tuples(self):
+ code = "x = (1, 2, 3)\nx[1]"
+ result = evaluate_python_code(code, {}, state={})
+ assert result == 2
+
+ code = """
+digits, i = [1, 2, 3], 1
+digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
+ evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
+
+ code = """
+def calculate_isbn_10_check_digit(number):
+ total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
+ remainder = total % 11
+ check_digit = 11 - remainder
+ if check_digit == 10:
+ return 'X'
+ elif check_digit == 11:
+ return '0'
+ else:
+ return str(check_digit)
+
+# Given 9-digit numbers
+numbers = [
+ "478225952",
+ "643485613",
+ "739394228",
+ "291726859",
+ "875262394",
+ "542617795",
+ "031810713",
+ "957007669",
+ "871467426"
+]
+
+# Calculate check digits for each number
+check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
+print(check_digits)
+"""
+ state = {}
+ evaluate_python_code(
+ code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
+ )
+
+ def test_listcomp(self):
+ code = "x = [i for i in range(3)]"
+ result = evaluate_python_code(code, {"range": range}, state={})
+ assert result == [0, 1, 2]
+
+ def test_break_continue(self):
+ code = "for i in range(10):\n if i == 5:\n break\ni"
+ result = evaluate_python_code(code, {"range": range}, state={})
+ assert result == 5
+
+ code = "for i in range(10):\n if i == 5:\n continue\ni"
+ result = evaluate_python_code(code, {"range": range}, state={})
+ assert result == 9
+
+ def test_call_int(self):
+ code = "import math\nstr(math.ceil(149))"
+ result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
+ assert result == "149"
+
+ def test_lambda(self):
+ code = "f = lambda x: x + 2\nf(3)"
+ result = evaluate_python_code(code, {}, state={})
+ assert result == 5
+
+ def test_dictcomp(self):
+ code = "x = {i: i**2 for i in range(3)}"
+ result = evaluate_python_code(code, {"range": range}, state={})
+ assert result == {0: 0, 1: 1, 2: 4}
+
+ code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
+ result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
+ assert result == {102: "b"}
+
+ code = """
+shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
+shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
+"""
+ result = evaluate_python_code(code, {}, state={})
+ assert result == {"A": ("a", "b"), "B": ("a", "b")}
+
+ def test_tuple_assignment(self):
+ code = "a, b = 0, 1\nb"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == 1
+
+ def test_while(self):
+ code = "i = 0\nwhile i < 3:\n i += 1\ni"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == 3
+
+ # test infinite loop
+ code = "i = 0\nwhile i < 3:\n i -= 1\ni"
+ with pytest.raises(InterpreterError) as e:
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert "iterations in While loop exceeded" in str(e)
+
+ # test lazy evaluation
+ code = """
+house_positions = [0, 7, 10, 15, 18, 22, 22]
+i, n, loc = 0, 7, 30
+while i < n and house_positions[i] <= loc:
+ i += 1
+"""
+ state = {}
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
+
+ def test_generator(self):
+ code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == [1, 4, 9, 16, 25]
+
+ def test_boolops(self):
+ code = """if (not (a > b and a > c)) or d > e:
+ best_city = "Brooklyn"
+else:
+ best_city = "Manhattan"
+ best_city
+ """
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
+ assert result == "Brooklyn"
+
+ code = """if d > e and a < b:
+ best_city = "Brooklyn"
+elif d < e and a < b:
+ best_city = "Sacramento"
+else:
+ best_city = "Manhattan"
+ best_city
+ """
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
+ assert result == "Sacramento"
+
+ def test_if_conditions(self):
+ code = """char='a'
+if char.isalpha():
+ print('2')"""
+ state = {}
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
+ assert state["print_outputs"] == "2\n"
+
+ def test_imports(self):
+ code = "import math\nmath.sqrt(4)"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == 2.0
+
+ code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == "lose"
+
+ code = "import time, re\ntime.sleep(0.1)"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result is None
+
+ code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == 1
+
+ code = "import itertools\nlist(itertools.islice(range(10), 3))"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == [0, 1, 2]
+
+ code = "import re\nre.search('a', 'abc').group()"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == "a"
+
+ code = "import stat\nstat.S_ISREG(0o100644)"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result
+
+ code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == 2.8
+
+ code = "import unicodedata\nunicodedata.name('A')"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == "LATIN CAPITAL LETTER A"
+
+ # Test submodules are handled properly, thus not raising error
+ code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
+
+ code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
+
+ def test_additional_imports(self):
+ code = "import numpy as np"
+ evaluate_python_code(code, authorized_imports=["numpy"], state={})
+
+ code = "import numpy.random as rd"
+ evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
+ evaluate_python_code(code, authorized_imports=["numpy"], state={})
+ with pytest.raises(InterpreterError):
+ evaluate_python_code(code, authorized_imports=["random"], state={})
+
+ def test_multiple_comparators(self):
+ code = "0 <= -1 < 4 and 0 <= -5 < 4"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert not result
+
+ code = "0 <= 1 < 4 and 0 <= -5 < 4"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert not result
+
+ code = "0 <= 4 < 4 and 0 <= 3 < 4"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert not result
+
+ code = "0 <= 3 < 4 and 0 <= 3 < 4"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result
+
+ def test_print_output(self):
+ code = "print('Hello world!')\nprint('Ok no one cares')"
+ state = {}
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
+ assert result is None
+ assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
+
+ # test print in function
+ code = """
+print("1")
+def function():
+ print("2")
+function()"""
+ state = {}
+ evaluate_python_code(code, {"print": print}, state=state)
+ assert state["print_outputs"] == "1\n2\n"
+
+ def test_tuple_target_in_iterator(self):
+ code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
+ result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert result == "Samuel"
+
+ def test_classes(self):
+ code = """
+class Animal:
+ species = "Generic Animal"
+
+ def __init__(self, name, age):
+ self.name = name
+ self.age = age
+
+ def sound(self):
+ return "The animal makes a sound."
+
+ def __str__(self):
+ return f"{self.name}, {self.age} years old"
+
+class Dog(Animal):
+ species = "Canine"
+
+ def __init__(self, name, age, breed):
+ super().__init__(name, age)
+ self.breed = breed
+
+ def sound(self):
+ return "The dog barks."
+
+ def __str__(self):
+ return f"{self.name}, {self.age} years old, {self.breed}"
+
+class Cat(Animal):
+ def sound(self):
+ return "The cat meows."
+
+ def __str__(self):
+ return f"{self.name}, {self.age} years old, {self.species}"
+
+
+# Testing multiple instances
+dog1 = Dog("Fido", 3, "Labrador")
+dog2 = Dog("Buddy", 5, "Golden Retriever")
+
+# Testing method with built-in function
+animals = [dog1, dog2, Cat("Whiskers", 2)]
+num_animals = len(animals)
+
+# Testing exceptions in methods
+class ExceptionTest:
+ def method_that_raises(self):
+ raise ValueError("An error occurred")
+
+try:
+ exc_test = ExceptionTest()
+ exc_test.method_that_raises()
+except ValueError as e:
+ exception_message = str(e)
+
+
+# Collecting results
+dog1_sound = dog1.sound()
+dog1_str = str(dog1)
+dog2_sound = dog2.sound()
+dog2_str = str(dog2)
+cat = Cat("Whiskers", 2)
+cat_sound = cat.sound()
+cat_str = str(cat)
+ """
+ state = {}
+ evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
+
+ # Assert results
+ assert state["dog1_sound"] == "The dog barks."
+ assert state["dog1_str"] == "Fido, 3 years old, Labrador"
+ assert state["dog2_sound"] == "The dog barks."
+ assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever"
+ assert state["cat_sound"] == "The cat meows."
+ assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal"
+ assert state["num_animals"] == 3
+ assert state["exception_message"] == "An error occurred"
+
+ def test_variable_args(self):
+ code = """
+def var_args_method(self, *args, **kwargs):
+ return sum(args) + sum(kwargs.values())
+
+var_args_method(1, 2, 3, x=4, y=5)
+"""
+ state = {}
+ result = evaluate_python_code(code, {"sum": sum}, state=state)
+ assert result == 15
+
+ def test_exceptions(self):
+ code = """
+def method_that_raises(self):
+ raise ValueError("An error occurred")
+
+try:
+ method_that_raises()
+except ValueError as e:
+ exception_message = str(e)
+ """
+ state = {}
+ evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
+ assert state["exception_message"] == "An error occurred"
+
+ def test_print(self):
+ code = "print(min([1, 2, 3]))"
+ state = {}
+ evaluate_python_code(code, {"min": min, "print": print}, state=state)
+ assert state["print_outputs"] == "1\n"
+
+ def test_types_as_objects(self):
+ code = "type_a = float(2); type_b = str; type_c = int"
+ state = {}
+ result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
+ assert result is int
+
+ def test_tuple_id(self):
+ code = """
+food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
+unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
+"""
+ state = {}
+ result = evaluate_python_code(code, {}, state=state)
+ assert result == ["orange", "pear"]
+
+ def test_nonsimple_augassign(self):
+ code = """
+counts_dict = {'a': 0}
+counts_dict['a'] += 1
+counts_list = [1, 2, 3]
+counts_list += [4, 5, 6]
+
+class Counter:
+ self.count = 0
+
+a = Counter()
+a.count += 1
+"""
+ state = {}
+ evaluate_python_code(code, {}, state=state)
+ assert state["counts_dict"] == {"a": 1}
+ assert state["counts_list"] == [1, 2, 3, 4, 5, 6]
+ assert state["a"].count == 1
+
+ def test_adding_int_to_list_raises_error(self):
+ code = """
+counts = [1, 2, 3]
+counts += 1"""
+ with pytest.raises(InterpreterError) as e:
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert "Cannot add non-list value 1 to a list." in str(e)
+
+ def test_error_highlights_correct_line_of_code(self):
+ code = """# Ok this is a very long code
+# It has many commented lines
+a = 1
+b = 2
+
+# Here is another piece
+counts = [1, 2, 3]
+counts += 1
+b += 1"""
+ with pytest.raises(InterpreterError) as e:
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert "Evaluation stopped at line 'counts += 1" in str(e)
+
+ def test_assert(self):
+ code = """
+assert 1 == 1
+assert 1 == 2
+"""
+ with pytest.raises(AssertionError) as e:
+ evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ assert "1 == 2" in str(e) and "1 == 1" not in str(e)
+
+ def test_with_context_manager(self):
+ code = """
+class SimpleLock:
+ def __init__(self):
+ self.locked = False
+
+ def __enter__(self):
+ self.locked = True
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.locked = False
+
+lock = SimpleLock()
+
+with lock as l:
+ assert l.locked == True
+
+assert lock.locked == False
+ """
+ state = {}
+ tools = {}
+ evaluate_python_code(code, tools, state=state)
+
+ def test_default_arg_in_function(self):
+ code = """
+def f(a, b=333, n=1000):
+ return b + n
+n = f(1, n=667)
+"""
+ res = evaluate_python_code(code, {}, {})
+ assert res == 1000
+
+ def test_set(self):
+ code = """
+S1 = {'a', 'b', 'c'}
+S2 = {'b', 'c', 'd'}
+S3 = S1.difference(S2)
+S4 = S1.intersection(S2)
+"""
+ state = {}
+ evaluate_python_code(code, {}, state=state)
+ assert state["S3"] == {"a"}
+ assert state["S4"] == {"b", "c"}
+
+ def test_break(self):
+ code = """
+i = 0
+
+while True:
+ i+= 1
+ if i==3:
+ break
+
+i"""
+ result = evaluate_python_code(code, {"print": print, "round": round}, state={})
+ assert result == 3
+
+ def test_return(self):
+ # test early returns
+ code = """
+def add_one(n, shift):
+ if True:
+ return n + shift
+ return n
+
+add_one(1, 1)
+"""
+ state = {}
+ result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
+ assert result == 2
+
+ # test returning None
+ code = """
+def returns_none(a):
+ return
+
+returns_none(1)
+"""
+ state = {}
+ result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
+ assert result is None
+
+ def test_nested_for_loop(self):
+ code = """
+all_res = []
+for i in range(10):
+ subres = []
+ for j in range(i):
+ subres.append(j)
+ all_res.append(subres)
+
+out = [i for sublist in all_res for i in sublist]
+out[:10]
+"""
+ state = {}
+ result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
+ assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
+
+ def test_pandas(self):
+ code = """
+import pandas as pd
+
+df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
+
+df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
+
+parts_with_5_set_count = df[df['SetCount'] == 5.0]
+parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
+"""
+ state = {}
+ result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
+ assert np.array_equal(result, [-1, 5])
+
+ code = """
+import pandas as pd
+
+df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
+print("HH0")
+
+# Filter the DataFrame to get only the rows with outdated atomic numbers
+filtered_df = df.loc[df['AtomicNumber'].isin([104])]
+"""
+ result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
+ assert np.array_equal(result.values[0], [104, 1])
+
+ code = """import pandas as pd
+data = pd.DataFrame.from_dict([
+ {"Pclass": 1, "Survived": 1},
+ {"Pclass": 2, "Survived": 0},
+ {"Pclass": 2, "Survived": 1}
+])
+survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
+"""
+ result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
+ assert result.values[1] == 0.5
+
+ def test_starred(self):
+ code = """
+from math import radians, sin, cos, sqrt, atan2
+
+def haversine(lat1, lon1, lat2, lon2):
+ R = 6371000 # Radius of the Earth in meters
+ lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
+ dlat = lat2 - lat1
+ dlon = lon2 - lon1
+ a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
+ c = 2 * atan2(sqrt(a), sqrt(1 - a))
+ distance = R * c
+ return distance
+
+coords_geneva = (46.1978, 6.1342)
+coords_barcelona = (41.3869, 2.1660)
+
+distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
+"""
+ result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
+ assert round(result, 1) == 622395.4
+
+ def test_for(self):
+ code = """
+shifts = {
+ "Worker A": ("6:45 pm", "8:00 pm"),
+ "Worker B": ("10:00 am", "11:45 am")
+}
+
+shift_intervals = {}
+for worker, (start, end) in shifts.items():
+ shift_intervals[worker] = end
+shift_intervals
+"""
+ result = evaluate_python_code(code, {"print": print, "map": map}, state={})
+ assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
diff --git a/tests/test_search.py b/tests/test_search.py
new file mode 100644
index 0000000..7e40e3c
--- /dev/null
+++ b/tests/test_search.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import load_tool
+
+from .test_tools_common import ToolTesterMixin
+
+
+class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = load_tool("web_search")
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ result = self.tool("Agents")
+ assert isinstance(result, list) and isinstance(result[0], dict)
diff --git a/tests/test_speech_to_text.py b/tests/test_speech_to_text.py
new file mode 100644
index 0000000..3d6e9a3
--- /dev/null
+++ b/tests/test_speech_to_text.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers import load_tool
+
+from .test_tools_common import ToolTesterMixin
+
+
+class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = load_tool("speech_to_text")
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ result = self.tool(np.ones(3000))
+ self.assertEqual(result, " Thank you.")
+
+ def test_exact_match_kwarg(self):
+ result = self.tool(audio=np.ones(3000))
+ self.assertEqual(result, " Thank you.")
diff --git a/tests/test_text_to_speech.py b/tests/test_text_to_speech.py
new file mode 100644
index 0000000..d8ed9af
--- /dev/null
+++ b/tests/test_text_to_speech.py
@@ -0,0 +1,50 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import load_tool
+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+from transformers.testing_utils import require_torch
+
+from .test_tools_common import ToolTesterMixin
+
+
+@require_torch
+class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = load_tool("text_to_speech")
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ # SpeechT5 isn't deterministic
+ torch.manual_seed(0)
+ result = self.tool("hey")
+ resulting_tensor = result.to_raw()
+ self.assertTrue(len(resulting_tensor.detach().shape) == 1)
+ self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
+
+ def test_exact_match_kwarg(self):
+ # SpeechT5 isn't deterministic
+ torch.manual_seed(0)
+ result = self.tool("hey")
+ resulting_tensor = result.to_raw()
+ self.assertTrue(len(resulting_tensor.detach().shape) == 1)
+ self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py
new file mode 100644
index 0000000..90b6b19
--- /dev/null
+++ b/tests/test_tools_common.py
@@ -0,0 +1,170 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+from pathlib import Path
+from typing import Dict, Union
+
+import numpy as np
+import pytest
+
+from transformers import is_torch_available, is_vision_available
+from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
+from transformers.agents.tools import Tool, tool
+from transformers.testing_utils import get_tests_dir
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+
+AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
+
+
+def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
+ inputs = {}
+
+ for input_name, input_desc in tool_inputs.items():
+ input_type = input_desc["type"]
+
+ if input_type == "string":
+ inputs[input_name] = "Text input"
+ elif input_type == "image":
+ inputs[input_name] = Image.open(
+ Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
+ ).resize((512, 512))
+ elif input_type == "audio":
+ inputs[input_name] = np.ones(3000)
+ else:
+ raise ValueError(f"Invalid type requested: {input_type}")
+
+ return inputs
+
+
+def output_type(output):
+ if isinstance(output, (str, AgentText)):
+ return "string"
+ elif isinstance(output, (Image.Image, AgentImage)):
+ return "image"
+ elif isinstance(output, (torch.Tensor, AgentAudio)):
+ return "audio"
+ else:
+ raise TypeError(f"Invalid output: {output}")
+
+
+class ToolTesterMixin:
+ def test_inputs_output(self):
+ self.assertTrue(hasattr(self.tool, "inputs"))
+ self.assertTrue(hasattr(self.tool, "output_type"))
+
+ inputs = self.tool.inputs
+ self.assertTrue(isinstance(inputs, dict))
+
+ for _, input_spec in inputs.items():
+ self.assertTrue("type" in input_spec)
+ self.assertTrue("description" in input_spec)
+ self.assertTrue(input_spec["type"] in AUTHORIZED_TYPES)
+ self.assertTrue(isinstance(input_spec["description"], str))
+
+ output_type = self.tool.output_type
+ self.assertTrue(output_type in AUTHORIZED_TYPES)
+
+ def test_common_attributes(self):
+ self.assertTrue(hasattr(self.tool, "description"))
+ self.assertTrue(hasattr(self.tool, "name"))
+ self.assertTrue(hasattr(self.tool, "inputs"))
+ self.assertTrue(hasattr(self.tool, "output_type"))
+
+ def test_agent_type_output(self):
+ inputs = create_inputs(self.tool.inputs)
+ output = self.tool(**inputs)
+ if self.tool.output_type != "any":
+ agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, agent_type))
+
+ def test_agent_types_inputs(self):
+ inputs = create_inputs(self.tool.inputs)
+ _inputs = []
+ for _input, expected_input in zip(inputs, self.tool.inputs.values()):
+ input_type = expected_input["type"]
+ _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+
+
+class ToolTests(unittest.TestCase):
+ def test_tool_init_with_decorator(self):
+ @tool
+ def coolfunc(a: str, b: int) -> float:
+ """Cool function
+
+ Args:
+ a: The first argument
+ b: The second one
+ """
+ return b + 2, a
+
+ assert coolfunc.output_type == "number"
+
+ def test_tool_init_vanilla(self):
+ class HFModelDownloadsTool(Tool):
+ name = "model_download_counter"
+ description = """
+ This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
+ It returns the name of the checkpoint."""
+
+ inputs = {
+ "task": {
+ "type": "string",
+ "description": "the task category (such as text-classification, depth-estimation, etc)",
+ }
+ }
+ output_type = "integer"
+
+ def forward(self, task):
+ return "best model"
+
+ tool = HFModelDownloadsTool()
+ assert list(tool.inputs.keys())[0] == "task"
+
+ def test_tool_init_decorator_raises_issues(self):
+ with pytest.raises(Exception) as e:
+
+ @tool
+ def coolfunc(a: str, b: int):
+ """Cool function
+
+ Args:
+ a: The first argument
+ b: The second one
+ """
+ return a + b
+
+ assert coolfunc.output_type == "number"
+ assert "Tool return type not found" in str(e)
+
+ with pytest.raises(Exception) as e:
+
+ @tool
+ def coolfunc(a: str, b: int) -> int:
+ """Cool function
+
+ Args:
+ a: The first argument
+ """
+ return b + a
+
+ assert coolfunc.output_type == "number"
+ assert "docstring has no description for the argument" in str(e)
diff --git a/tests/test_translation.py b/tests/test_translation.py
new file mode 100644
index 0000000..9027dd1
--- /dev/null
+++ b/tests/test_translation.py
@@ -0,0 +1,67 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import load_tool
+from transformers.agents.agent_types import AGENT_TYPE_MAPPING
+
+from .test_tools_common import ToolTesterMixin, output_type
+
+
+class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = load_tool("translation")
+ self.tool.setup()
+ self.remote_tool = load_tool("translation", remote=True)
+
+ def test_exact_match_arg(self):
+ result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
+ self.assertEqual(result, "- Hรฉ, comment รงa va?")
+
+ def test_exact_match_kwarg(self):
+ result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
+ self.assertEqual(result, "- Hรฉ, comment รงa va?")
+
+ def test_call(self):
+ inputs = ["Hey, what's up?", "English", "Spanish"]
+ output = self.tool(*inputs)
+
+ self.assertEqual(output_type(output), self.tool.output_type)
+
+ def test_agent_type_output(self):
+ inputs = ["Hey, what's up?", "English", "Spanish"]
+ output = self.tool(*inputs)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))
+
+ def test_agent_types_inputs(self):
+ example_inputs = {
+ "text": "Hey, what's up?",
+ "src_lang": "English",
+ "tgt_lang": "Spanish",
+ }
+
+ _inputs = []
+ for input_name in example_inputs.keys():
+ example_input = example_inputs[input_name]
+ input_description = self.tool.inputs[input_name]
+ input_type = input_description["type"]
+ _inputs.append(AGENT_TYPE_MAPPING[input_type](example_input))
+
+ # Should not raise an error
+ output = self.tool(**example_inputs)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))