Pass tests
This commit is contained in:
parent
67deb6808f
commit
1606b9a80c
|
@ -62,7 +62,6 @@ else:
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import (
|
from .agents import (
|
||||||
Agent,
|
Agent,
|
||||||
CodeAgent,
|
|
||||||
ManagedAgent,
|
ManagedAgent,
|
||||||
ReactAgent,
|
ReactAgent,
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
|
|
|
@ -47,9 +47,6 @@ from .tools import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
HUGGINGFACE_DEFAULT_TOOLS = {}
|
|
||||||
|
|
||||||
|
|
||||||
class AgentError(Exception):
|
class AgentError(Exception):
|
||||||
"""Base class for other agent-related exceptions"""
|
"""Base class for other agent-related exceptions"""
|
||||||
|
|
||||||
|
@ -145,14 +142,18 @@ Here is a list of the team members that you can call:"""
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_managed_agents_descriptions(
|
def format_prompt_with_managed_agents_descriptions(
|
||||||
prompt_template, managed_agents=None
|
prompt_template, managed_agents, agent_descriptions_placeholder: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
if managed_agents is not None:
|
if agent_descriptions_placeholder is None:
|
||||||
|
agent_descriptions_placeholder = "{{managed_agents_descriptions}}"
|
||||||
|
if agent_descriptions_placeholder not in prompt_template:
|
||||||
|
raise ValueError(f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'")
|
||||||
|
if len(managed_agents.keys()) > 0:
|
||||||
return prompt_template.replace(
|
return prompt_template.replace(
|
||||||
"<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)
|
agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return prompt_template.replace("<<managed_agents_descriptions>>", "")
|
return prompt_template.replace(agent_descriptions_placeholder, "")
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_imports(
|
def format_prompt_with_imports(
|
||||||
|
@ -220,12 +221,8 @@ class BaseAgent:
|
||||||
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
||||||
self._toolbox.add_tool(FinalAnswerTool())
|
self._toolbox.add_tool(FinalAnswerTool())
|
||||||
|
|
||||||
self.system_prompt = format_prompt_with_tools(
|
self.system_prompt = self.initialize_system_prompt()
|
||||||
self._toolbox, self.system_prompt_template, self.tool_description_template
|
print("SYS0:", self.system_prompt)
|
||||||
)
|
|
||||||
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
|
||||||
self.system_prompt, self.managed_agents
|
|
||||||
)
|
|
||||||
self.prompt_messages = None
|
self.prompt_messages = None
|
||||||
self.logs = []
|
self.logs = []
|
||||||
self.task = None
|
self.task = None
|
||||||
|
@ -353,7 +350,7 @@ class BaseAgent:
|
||||||
split[-2],
|
split[-2],
|
||||||
split[-1],
|
split[-1],
|
||||||
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise AgentParsingError(
|
raise AgentParsingError(
|
||||||
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
||||||
)
|
)
|
||||||
|
@ -909,8 +906,9 @@ class CodeAgent(ReactAgent):
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(
|
||||||
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
|
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
|
||||||
)
|
)
|
||||||
|
print("SYSS:", self.system_prompt)
|
||||||
self.system_prompt = self.system_prompt.replace(
|
self.system_prompt = self.system_prompt.replace(
|
||||||
"<<authorized_imports>>", str(self.authorized_imports)
|
"{{authorized_imports}}", str(self.authorized_imports)
|
||||||
)
|
)
|
||||||
self.custom_tools = {}
|
self.custom_tools = {}
|
||||||
|
|
||||||
|
|
|
@ -135,6 +135,8 @@ final_answer(caption)
|
||||||
Above example were using tools that might not exist for you. You only have access to these tools:
|
Above example were using tools that might not exist for you. You only have access to these tools:
|
||||||
{{tool_names}}
|
{{tool_names}}
|
||||||
|
|
||||||
|
{{managed_agents_descriptions}}
|
||||||
|
|
||||||
Remember to make sure that variables you use are all defined. In particular don't import packages!
|
Remember to make sure that variables you use are all defined. In particular don't import packages!
|
||||||
Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error.
|
Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error.
|
||||||
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
||||||
|
@ -260,8 +262,11 @@ Action:
|
||||||
|
|
||||||
|
|
||||||
Above example were using notional tools that might not exist for you. You only have access to these tools:
|
Above example were using notional tools that might not exist for you. You only have access to these tools:
|
||||||
|
|
||||||
{{tool_descriptions}}
|
{{tool_descriptions}}
|
||||||
|
|
||||||
|
{{managed_agents_descriptions}}
|
||||||
|
|
||||||
Here are the rules you should always follow to solve your task:
|
Here are the rules you should always follow to solve your task:
|
||||||
1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail.
|
1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail.
|
||||||
2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead.
|
2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead.
|
||||||
|
@ -355,7 +360,7 @@ Above example were using notional tools that might not exist for you. On top of
|
||||||
|
|
||||||
{{tool_descriptions}}
|
{{tool_descriptions}}
|
||||||
|
|
||||||
<<managed_agents_descriptions>>
|
{{managed_agents_descriptions}}
|
||||||
|
|
||||||
Here are the rules you should always follow to solve your task:
|
Here are the rules you should always follow to solve your task:
|
||||||
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
|
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
|
||||||
|
|
|
@ -643,8 +643,10 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
|
||||||
|
|
||||||
|
|
||||||
def get_tool_description_with_args(
|
def get_tool_description_with_args(
|
||||||
tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
tool: Tool, description_template: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
|
if description_template is None:
|
||||||
|
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||||
compiled_template = compile_jinja_template(description_template)
|
compiled_template = compile_jinja_template(description_template)
|
||||||
rendered = compiled_template.render(
|
rendered = compiled_template.render(
|
||||||
tool=tool,
|
tool=tool,
|
||||||
|
@ -1080,6 +1082,9 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
return SpecificTool()
|
return SpecificTool()
|
||||||
|
|
||||||
|
|
||||||
|
HUGGINGFACE_DEFAULT_TOOLS = {}
|
||||||
|
|
||||||
|
|
||||||
class Toolbox:
|
class Toolbox:
|
||||||
"""
|
"""
|
||||||
The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
|
The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
|
||||||
|
@ -1110,7 +1115,7 @@ class Toolbox:
|
||||||
"""Get all tools currently in the toolbox"""
|
"""Get all tools currently in the toolbox"""
|
||||||
return self._tools
|
return self._tools
|
||||||
|
|
||||||
def show_tool_descriptions(self, tool_description_template: str = None) -> str:
|
def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Returns the description of all tools in the toolbox
|
Returns the description of all tools in the toolbox
|
||||||
|
|
||||||
|
|
|
@ -1,279 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# coding=utf-8
|
|
||||||
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
||||||
from .tools import PipelineTool
|
|
||||||
|
|
||||||
|
|
||||||
LANGUAGE_CODES = {
|
|
||||||
"Acehnese Arabic": "ace_Arab",
|
|
||||||
"Acehnese Latin": "ace_Latn",
|
|
||||||
"Mesopotamian Arabic": "acm_Arab",
|
|
||||||
"Ta'izzi-Adeni Arabic": "acq_Arab",
|
|
||||||
"Tunisian Arabic": "aeb_Arab",
|
|
||||||
"Afrikaans": "afr_Latn",
|
|
||||||
"South Levantine Arabic": "ajp_Arab",
|
|
||||||
"Akan": "aka_Latn",
|
|
||||||
"Amharic": "amh_Ethi",
|
|
||||||
"North Levantine Arabic": "apc_Arab",
|
|
||||||
"Modern Standard Arabic": "arb_Arab",
|
|
||||||
"Modern Standard Arabic Romanized": "arb_Latn",
|
|
||||||
"Najdi Arabic": "ars_Arab",
|
|
||||||
"Moroccan Arabic": "ary_Arab",
|
|
||||||
"Egyptian Arabic": "arz_Arab",
|
|
||||||
"Assamese": "asm_Beng",
|
|
||||||
"Asturian": "ast_Latn",
|
|
||||||
"Awadhi": "awa_Deva",
|
|
||||||
"Central Aymara": "ayr_Latn",
|
|
||||||
"South Azerbaijani": "azb_Arab",
|
|
||||||
"North Azerbaijani": "azj_Latn",
|
|
||||||
"Bashkir": "bak_Cyrl",
|
|
||||||
"Bambara": "bam_Latn",
|
|
||||||
"Balinese": "ban_Latn",
|
|
||||||
"Belarusian": "bel_Cyrl",
|
|
||||||
"Bemba": "bem_Latn",
|
|
||||||
"Bengali": "ben_Beng",
|
|
||||||
"Bhojpuri": "bho_Deva",
|
|
||||||
"Banjar Arabic": "bjn_Arab",
|
|
||||||
"Banjar Latin": "bjn_Latn",
|
|
||||||
"Standard Tibetan": "bod_Tibt",
|
|
||||||
"Bosnian": "bos_Latn",
|
|
||||||
"Buginese": "bug_Latn",
|
|
||||||
"Bulgarian": "bul_Cyrl",
|
|
||||||
"Catalan": "cat_Latn",
|
|
||||||
"Cebuano": "ceb_Latn",
|
|
||||||
"Czech": "ces_Latn",
|
|
||||||
"Chokwe": "cjk_Latn",
|
|
||||||
"Central Kurdish": "ckb_Arab",
|
|
||||||
"Crimean Tatar": "crh_Latn",
|
|
||||||
"Welsh": "cym_Latn",
|
|
||||||
"Danish": "dan_Latn",
|
|
||||||
"German": "deu_Latn",
|
|
||||||
"Southwestern Dinka": "dik_Latn",
|
|
||||||
"Dyula": "dyu_Latn",
|
|
||||||
"Dzongkha": "dzo_Tibt",
|
|
||||||
"Greek": "ell_Grek",
|
|
||||||
"English": "eng_Latn",
|
|
||||||
"Esperanto": "epo_Latn",
|
|
||||||
"Estonian": "est_Latn",
|
|
||||||
"Basque": "eus_Latn",
|
|
||||||
"Ewe": "ewe_Latn",
|
|
||||||
"Faroese": "fao_Latn",
|
|
||||||
"Fijian": "fij_Latn",
|
|
||||||
"Finnish": "fin_Latn",
|
|
||||||
"Fon": "fon_Latn",
|
|
||||||
"French": "fra_Latn",
|
|
||||||
"Friulian": "fur_Latn",
|
|
||||||
"Nigerian Fulfulde": "fuv_Latn",
|
|
||||||
"Scottish Gaelic": "gla_Latn",
|
|
||||||
"Irish": "gle_Latn",
|
|
||||||
"Galician": "glg_Latn",
|
|
||||||
"Guarani": "grn_Latn",
|
|
||||||
"Gujarati": "guj_Gujr",
|
|
||||||
"Haitian Creole": "hat_Latn",
|
|
||||||
"Hausa": "hau_Latn",
|
|
||||||
"Hebrew": "heb_Hebr",
|
|
||||||
"Hindi": "hin_Deva",
|
|
||||||
"Chhattisgarhi": "hne_Deva",
|
|
||||||
"Croatian": "hrv_Latn",
|
|
||||||
"Hungarian": "hun_Latn",
|
|
||||||
"Armenian": "hye_Armn",
|
|
||||||
"Igbo": "ibo_Latn",
|
|
||||||
"Ilocano": "ilo_Latn",
|
|
||||||
"Indonesian": "ind_Latn",
|
|
||||||
"Icelandic": "isl_Latn",
|
|
||||||
"Italian": "ita_Latn",
|
|
||||||
"Javanese": "jav_Latn",
|
|
||||||
"Japanese": "jpn_Jpan",
|
|
||||||
"Kabyle": "kab_Latn",
|
|
||||||
"Jingpho": "kac_Latn",
|
|
||||||
"Kamba": "kam_Latn",
|
|
||||||
"Kannada": "kan_Knda",
|
|
||||||
"Kashmiri Arabic": "kas_Arab",
|
|
||||||
"Kashmiri Devanagari": "kas_Deva",
|
|
||||||
"Georgian": "kat_Geor",
|
|
||||||
"Central Kanuri Arabic": "knc_Arab",
|
|
||||||
"Central Kanuri Latin": "knc_Latn",
|
|
||||||
"Kazakh": "kaz_Cyrl",
|
|
||||||
"Kabiyè": "kbp_Latn",
|
|
||||||
"Kabuverdianu": "kea_Latn",
|
|
||||||
"Khmer": "khm_Khmr",
|
|
||||||
"Kikuyu": "kik_Latn",
|
|
||||||
"Kinyarwanda": "kin_Latn",
|
|
||||||
"Kyrgyz": "kir_Cyrl",
|
|
||||||
"Kimbundu": "kmb_Latn",
|
|
||||||
"Northern Kurdish": "kmr_Latn",
|
|
||||||
"Kikongo": "kon_Latn",
|
|
||||||
"Korean": "kor_Hang",
|
|
||||||
"Lao": "lao_Laoo",
|
|
||||||
"Ligurian": "lij_Latn",
|
|
||||||
"Limburgish": "lim_Latn",
|
|
||||||
"Lingala": "lin_Latn",
|
|
||||||
"Lithuanian": "lit_Latn",
|
|
||||||
"Lombard": "lmo_Latn",
|
|
||||||
"Latgalian": "ltg_Latn",
|
|
||||||
"Luxembourgish": "ltz_Latn",
|
|
||||||
"Luba-Kasai": "lua_Latn",
|
|
||||||
"Ganda": "lug_Latn",
|
|
||||||
"Luo": "luo_Latn",
|
|
||||||
"Mizo": "lus_Latn",
|
|
||||||
"Standard Latvian": "lvs_Latn",
|
|
||||||
"Magahi": "mag_Deva",
|
|
||||||
"Maithili": "mai_Deva",
|
|
||||||
"Malayalam": "mal_Mlym",
|
|
||||||
"Marathi": "mar_Deva",
|
|
||||||
"Minangkabau Arabic ": "min_Arab",
|
|
||||||
"Minangkabau Latin": "min_Latn",
|
|
||||||
"Macedonian": "mkd_Cyrl",
|
|
||||||
"Plateau Malagasy": "plt_Latn",
|
|
||||||
"Maltese": "mlt_Latn",
|
|
||||||
"Meitei Bengali": "mni_Beng",
|
|
||||||
"Halh Mongolian": "khk_Cyrl",
|
|
||||||
"Mossi": "mos_Latn",
|
|
||||||
"Maori": "mri_Latn",
|
|
||||||
"Burmese": "mya_Mymr",
|
|
||||||
"Dutch": "nld_Latn",
|
|
||||||
"Norwegian Nynorsk": "nno_Latn",
|
|
||||||
"Norwegian Bokmål": "nob_Latn",
|
|
||||||
"Nepali": "npi_Deva",
|
|
||||||
"Northern Sotho": "nso_Latn",
|
|
||||||
"Nuer": "nus_Latn",
|
|
||||||
"Nyanja": "nya_Latn",
|
|
||||||
"Occitan": "oci_Latn",
|
|
||||||
"West Central Oromo": "gaz_Latn",
|
|
||||||
"Odia": "ory_Orya",
|
|
||||||
"Pangasinan": "pag_Latn",
|
|
||||||
"Eastern Panjabi": "pan_Guru",
|
|
||||||
"Papiamento": "pap_Latn",
|
|
||||||
"Western Persian": "pes_Arab",
|
|
||||||
"Polish": "pol_Latn",
|
|
||||||
"Portuguese": "por_Latn",
|
|
||||||
"Dari": "prs_Arab",
|
|
||||||
"Southern Pashto": "pbt_Arab",
|
|
||||||
"Ayacucho Quechua": "quy_Latn",
|
|
||||||
"Romanian": "ron_Latn",
|
|
||||||
"Rundi": "run_Latn",
|
|
||||||
"Russian": "rus_Cyrl",
|
|
||||||
"Sango": "sag_Latn",
|
|
||||||
"Sanskrit": "san_Deva",
|
|
||||||
"Santali": "sat_Olck",
|
|
||||||
"Sicilian": "scn_Latn",
|
|
||||||
"Shan": "shn_Mymr",
|
|
||||||
"Sinhala": "sin_Sinh",
|
|
||||||
"Slovak": "slk_Latn",
|
|
||||||
"Slovenian": "slv_Latn",
|
|
||||||
"Samoan": "smo_Latn",
|
|
||||||
"Shona": "sna_Latn",
|
|
||||||
"Sindhi": "snd_Arab",
|
|
||||||
"Somali": "som_Latn",
|
|
||||||
"Southern Sotho": "sot_Latn",
|
|
||||||
"Spanish": "spa_Latn",
|
|
||||||
"Tosk Albanian": "als_Latn",
|
|
||||||
"Sardinian": "srd_Latn",
|
|
||||||
"Serbian": "srp_Cyrl",
|
|
||||||
"Swati": "ssw_Latn",
|
|
||||||
"Sundanese": "sun_Latn",
|
|
||||||
"Swedish": "swe_Latn",
|
|
||||||
"Swahili": "swh_Latn",
|
|
||||||
"Silesian": "szl_Latn",
|
|
||||||
"Tamil": "tam_Taml",
|
|
||||||
"Tatar": "tat_Cyrl",
|
|
||||||
"Telugu": "tel_Telu",
|
|
||||||
"Tajik": "tgk_Cyrl",
|
|
||||||
"Tagalog": "tgl_Latn",
|
|
||||||
"Thai": "tha_Thai",
|
|
||||||
"Tigrinya": "tir_Ethi",
|
|
||||||
"Tamasheq Latin": "taq_Latn",
|
|
||||||
"Tamasheq Tifinagh": "taq_Tfng",
|
|
||||||
"Tok Pisin": "tpi_Latn",
|
|
||||||
"Tswana": "tsn_Latn",
|
|
||||||
"Tsonga": "tso_Latn",
|
|
||||||
"Turkmen": "tuk_Latn",
|
|
||||||
"Tumbuka": "tum_Latn",
|
|
||||||
"Turkish": "tur_Latn",
|
|
||||||
"Twi": "twi_Latn",
|
|
||||||
"Central Atlas Tamazight": "tzm_Tfng",
|
|
||||||
"Uyghur": "uig_Arab",
|
|
||||||
"Ukrainian": "ukr_Cyrl",
|
|
||||||
"Umbundu": "umb_Latn",
|
|
||||||
"Urdu": "urd_Arab",
|
|
||||||
"Northern Uzbek": "uzn_Latn",
|
|
||||||
"Venetian": "vec_Latn",
|
|
||||||
"Vietnamese": "vie_Latn",
|
|
||||||
"Waray": "war_Latn",
|
|
||||||
"Wolof": "wol_Latn",
|
|
||||||
"Xhosa": "xho_Latn",
|
|
||||||
"Eastern Yiddish": "ydd_Hebr",
|
|
||||||
"Yoruba": "yor_Latn",
|
|
||||||
"Yue Chinese": "yue_Hant",
|
|
||||||
"Chinese Simplified": "zho_Hans",
|
|
||||||
"Chinese Traditional": "zho_Hant",
|
|
||||||
"Standard Malay": "zsm_Latn",
|
|
||||||
"Zulu": "zul_Latn",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationTool(PipelineTool):
|
|
||||||
"""
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```py
|
|
||||||
from transformers.agents import TranslationTool
|
|
||||||
|
|
||||||
translator = TranslationTool()
|
|
||||||
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
lang_to_code = LANGUAGE_CODES
|
|
||||||
default_checkpoint = "facebook/nllb-200-distilled-600M"
|
|
||||||
description = (
|
|
||||||
"This is a tool that translates text from a language to another."
|
|
||||||
f"Both `src_lang`and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
|
|
||||||
)
|
|
||||||
name = "translator"
|
|
||||||
pre_processor_class = AutoTokenizer
|
|
||||||
model_class = AutoModelForSeq2SeqLM
|
|
||||||
|
|
||||||
inputs = {
|
|
||||||
"text": {"type": "string", "description": "The text to translate"},
|
|
||||||
"src_lang": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
|
|
||||||
},
|
|
||||||
"tgt_lang": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
output_type = "string"
|
|
||||||
|
|
||||||
def encode(self, text, src_lang, tgt_lang):
|
|
||||||
if src_lang not in self.lang_to_code:
|
|
||||||
raise ValueError(f"{src_lang} is not a supported language.")
|
|
||||||
if tgt_lang not in self.lang_to_code:
|
|
||||||
raise ValueError(f"{tgt_lang} is not a supported language.")
|
|
||||||
src_lang = self.lang_to_code[src_lang]
|
|
||||||
tgt_lang = self.lang_to_code[tgt_lang]
|
|
||||||
return self.pre_processor._build_translation_inputs(
|
|
||||||
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
return self.model.generate(**inputs)
|
|
||||||
|
|
||||||
def decode(self, outputs):
|
|
||||||
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
|
|
|
@ -1,10 +1,10 @@
|
||||||
from agents import load_tool, CodeAgent, HfApiEngine
|
from agents import load_tool, CodeAgent, HfApiEngine
|
||||||
|
from agents.search import DuckDuckGoSearchTool
|
||||||
|
|
||||||
# Import tool from Hub
|
# Import tool from Hub
|
||||||
image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
|
image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
|
||||||
|
|
||||||
# Import tool from LangChain
|
# Import tool from LangChain
|
||||||
from agents.search import DuckDuckGoSearchTool
|
|
||||||
|
|
||||||
search_tool = DuckDuckGoSearchTool()
|
search_tool = DuckDuckGoSearchTool()
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent
|
from agents import stream_to_gradio, HfApiEngine, load_tool, CodeAgent
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
image_generation_tool = load_tool("m-ric/text-to-image")
|
image_generation_tool = load_tool("m-ric/text-to-image")
|
||||||
|
|
||||||
|
@ -6,9 +7,6 @@ llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
|
||||||
|
|
||||||
agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
|
agent = CodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
|
|
||||||
def interact_with_agent(prompt, messages):
|
def interact_with_agent(prompt, messages):
|
||||||
messages.append(gr.ChatMessage(role="user", content=prompt))
|
messages.append(gr.ChatMessage(role="user", content=prompt))
|
||||||
yield messages
|
yield messages
|
||||||
|
|
|
@ -1,14 +1,19 @@
|
||||||
from agents.llm_engine import TransformersEngine
|
from agents import JsonAgent
|
||||||
from agents import CodeAgent, JsonAgent
|
from agents import tool
|
||||||
|
import webbrowser
|
||||||
import requests
|
import requests
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import random
|
||||||
|
from llama_cpp import Llama
|
||||||
|
from agents import tool
|
||||||
|
import webbrowser
|
||||||
|
from typing import List, Generator, Dict, Any
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
model_repo="andito/SmolLM2-1.7B-Instruct-F16-GGUF"
|
model_repo="andito/SmolLM2-1.7B-Instruct-F16-GGUF"
|
||||||
model_filename="smollm2-1.7b-8k-dpo-f16.gguf"
|
model_filename="smollm2-1.7b-8k-dpo-f16.gguf"
|
||||||
|
|
||||||
import random
|
|
||||||
from llama_cpp import Llama
|
|
||||||
|
|
||||||
model = Llama.from_pretrained(
|
model = Llama.from_pretrained(
|
||||||
repo_id=model_repo,
|
repo_id=model_repo,
|
||||||
|
@ -55,10 +60,6 @@ The example format is as follows. Please make sure the parameter type is correct
|
||||||
... (more tool calls as required)
|
... (more tool calls as required)
|
||||||
]</tool_call>"""
|
]</tool_call>"""
|
||||||
|
|
||||||
|
|
||||||
from agents import tool
|
|
||||||
import webbrowser
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_random_number_between(min: int, max: int) -> int:
|
def get_random_number_between(min: int, max: int) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -110,9 +111,7 @@ def open_webbrowser(url: str) -> str:
|
||||||
webbrowser.open(url)
|
webbrowser.open(url)
|
||||||
return f"I opened {url.replace('https://', '').replace('www.', '')} in the browser."
|
return f"I opened {url.replace('https://', '').replace('www.', '')} in the browser."
|
||||||
|
|
||||||
from typing import List, Dict, Generator, Any
|
‹
|
||||||
import re
|
|
||||||
import json
|
|
||||||
def _parse_response(self, text: str) -> List[Dict[str, Any]]:
|
def _parse_response(self, text: str) -> List[Dict[str, Any]]:
|
||||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||||
matches = re.findall(pattern, text, re.DOTALL)
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from agents import load_tool, CodeAgent, JsonAgent, HfApiEngine
|
from agents import load_tool, CodeAgent, HfApiEngine
|
||||||
from agents.prompts import ONESHOT_CODE_SYSTEM_PROMPT
|
from agents.prompts import ONESHOT_CODE_SYSTEM_PROMPT
|
||||||
|
|
||||||
# Import tool from Hub
|
# Import tool from Hub
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 678 KiB |
|
@ -27,8 +27,6 @@ from transformers.testing_utils import (
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
is_soundfile_availble,
|
is_soundfile_availble,
|
||||||
is_torch_available,
|
|
||||||
is_vision_available,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -93,7 +91,7 @@ class AgentImageTests(unittest.TestCase):
|
||||||
self.assertTrue(os.path.exists(path))
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
def test_from_string(self):
|
def test_from_string(self):
|
||||||
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
|
||||||
image = Image.open(path)
|
image = Image.open(path)
|
||||||
agent_type = AgentImage(path)
|
agent_type = AgentImage(path)
|
||||||
|
|
||||||
|
@ -105,7 +103,7 @@ class AgentImageTests(unittest.TestCase):
|
||||||
self.assertTrue(os.path.exists(path))
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
def test_from_image(self):
|
def test_from_image(self):
|
||||||
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
path = Path(get_tests_dir("fixtures/")) / "000000039769.png"
|
||||||
image = Image.open(path)
|
image = Image.open(path)
|
||||||
agent_type = AgentImage(image)
|
agent_type = AgentImage(image)
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ import pytest
|
||||||
from agents.agent_types import AgentText
|
from agents.agent_types import AgentText
|
||||||
from agents.agents import (
|
from agents.agents import (
|
||||||
AgentMaxIterationsError,
|
AgentMaxIterationsError,
|
||||||
CodeAgent,
|
|
||||||
ManagedAgent,
|
ManagedAgent,
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
JsonAgent,
|
JsonAgent,
|
||||||
|
@ -162,14 +161,14 @@ class AgentTests(unittest.TestCase):
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert output == "7.2904"
|
assert output == "7.2904"
|
||||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
||||||
assert agent.logs[1]["observation"] == "7.2904"
|
assert agent.logs[2].observation == "7.2904"
|
||||||
assert (
|
assert (
|
||||||
agent.logs[1]["rationale"].strip()
|
agent.logs[2].rationale.strip()
|
||||||
== "Thought: I should multiply 2 by 3.6452. special_marker"
|
== "Thought: I should multiply 2 by 3.6452. special_marker"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
agent.logs[2]["llm_output"]
|
agent.logs[3].llm_output
|
||||||
== """
|
== """
|
||||||
Thought: I can now answer the initial question
|
Thought: I can now answer the initial question
|
||||||
Action:
|
Action:
|
||||||
|
@ -187,8 +186,8 @@ Action:
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, float)
|
assert isinstance(output, float)
|
||||||
assert output == 7.2904
|
assert output == 7.2904
|
||||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
||||||
assert agent.logs[2]["tool_call"] == {
|
assert agent.logs[3].tool_call == {
|
||||||
"tool_arguments": "final_answer(7.2904)",
|
"tool_arguments": "final_answer(7.2904)",
|
||||||
"tool_name": "code interpreter",
|
"tool_name": "code interpreter",
|
||||||
}
|
}
|
||||||
|
@ -212,10 +211,9 @@ Action:
|
||||||
max_iterations=5,
|
max_iterations=5,
|
||||||
)
|
)
|
||||||
agent.run("What is 2 multiplied by 3.6452?")
|
agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert len(agent.logs) == 7
|
assert len(agent.logs) == 8
|
||||||
assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError
|
assert type(agent.logs[-1].error) is AgentMaxIterationsError
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_init_agent_with_different_toolsets(self):
|
def test_init_agent_with_different_toolsets(self):
|
||||||
toolset_1 = []
|
toolset_1 = []
|
||||||
agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
||||||
|
@ -245,8 +243,8 @@ Action:
|
||||||
# check that python_interpreter base tool does not get added to code agents
|
# check that python_interpreter base tool does not get added to code agents
|
||||||
agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 7
|
len(agent.toolbox.tools) == 2
|
||||||
) # added final_answer tool + 6 base tools (excluding interpreter)
|
) # added final_answer tool + search
|
||||||
|
|
||||||
def test_function_persistence_across_steps(self):
|
def test_function_persistence_across_steps(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
|
@ -273,7 +271,8 @@ Action:
|
||||||
managed_agents=[managed_agent],
|
managed_agents=[managed_agent],
|
||||||
)
|
)
|
||||||
assert "You can also give requests to team members." not in agent.system_prompt
|
assert "You can also give requests to team members." not in agent.system_prompt
|
||||||
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
|
print("ok1")
|
||||||
|
assert "{{managed_agents_descriptions}}" not in agent.system_prompt
|
||||||
assert (
|
assert (
|
||||||
"You can also give requests to team members." in manager_agent.system_prompt
|
"You can also give requests to team members." in manager_agent.system_prompt
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,25 +20,9 @@ import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock, skip
|
from unittest import mock, skip
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
from .test_utils import slow, skip, get_launch_command, TempDirTestCase
|
||||||
|
|
||||||
from accelerate.test_utils.examples import compare_against_test
|
|
||||||
from accelerate.test_utils.testing import (
|
|
||||||
TempDirTestCase,
|
|
||||||
get_launch_command,
|
|
||||||
require_huggingface_suite,
|
|
||||||
require_multi_device,
|
|
||||||
require_multi_gpu,
|
|
||||||
require_non_xpu,
|
|
||||||
require_pippy,
|
|
||||||
require_schedulefree,
|
|
||||||
require_trackers,
|
|
||||||
run_command,
|
|
||||||
slow,
|
|
||||||
)
|
|
||||||
from accelerate.utils import write_basic_config
|
|
||||||
|
|
||||||
|
|
||||||
# DataLoaders built from `test_samples/MRPC` for quick testing
|
# DataLoaders built from `test_samples/MRPC` for quick testing
|
||||||
# Should mock `{script_name}.get_dataloaders` via:
|
# Should mock `{script_name}.get_dataloaders` via:
|
||||||
|
@ -62,242 +46,51 @@ EXCLUDE_EXAMPLES = [
|
||||||
"profiler.py",
|
"profiler.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
|
||||||
class ExampleDifferenceTests(unittest.TestCase):
|
|
||||||
|
class SubprocessCallException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(command: List[str], return_stdout=False, env=None):
|
||||||
"""
|
"""
|
||||||
This TestCase checks that all of the `complete_*` scripts contain all of the
|
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
|
||||||
information found in the `by_feature` scripts, line for line. If one fails,
|
if an error occured while running `command`
|
||||||
then a complete example does not contain all of the features in the features
|
|
||||||
scripts, and should be updated.
|
|
||||||
|
|
||||||
Each example script should be a single test (such as `test_nlp_example`),
|
|
||||||
and should run `one_complete_example` twice: once with `parser_only=True`,
|
|
||||||
and the other with `parser_only=False`. This is so that when the test
|
|
||||||
failures are returned to the user, they understand if the discrepancy lies in
|
|
||||||
the `main` function, or the `training_loop` function. Otherwise it will be
|
|
||||||
unclear.
|
|
||||||
|
|
||||||
Also, if there are any expected differences between the base script used and
|
|
||||||
`complete_nlp_example.py` (the canonical base script), these should be included in
|
|
||||||
`special_strings`. These would be differences in how something is logged, print statements,
|
|
||||||
etc (such as calls to `Accelerate.log()`)
|
|
||||||
"""
|
"""
|
||||||
|
# Cast every path in `command` to a string
|
||||||
by_feature_path = Path("examples", "by_feature").resolve()
|
for i, c in enumerate(command):
|
||||||
examples_path = Path("examples").resolve()
|
if isinstance(c, Path):
|
||||||
|
command[i] = str(c)
|
||||||
def one_complete_example(
|
if env is None:
|
||||||
self,
|
env = os.environ.copy()
|
||||||
complete_file_name: str,
|
try:
|
||||||
parser_only: bool,
|
output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
|
||||||
secondary_filename: str = None,
|
if return_stdout:
|
||||||
special_strings: list = None,
|
if hasattr(output, "decode"):
|
||||||
):
|
output = output.decode("utf-8")
|
||||||
"""
|
return output
|
||||||
Tests a single `complete` example against all of the implemented `by_feature` scripts
|
except subprocess.CalledProcessError as e:
|
||||||
|
raise SubprocessCallException(
|
||||||
Args:
|
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
||||||
complete_file_name (`str`):
|
) from e
|
||||||
The filename of a complete example
|
|
||||||
parser_only (`bool`):
|
|
||||||
Whether to look at the main training function, or the argument parser
|
|
||||||
secondary_filename (`str`, *optional*):
|
|
||||||
A potential secondary base file to strip all script information not relevant for checking,
|
|
||||||
such as "cv_example.py" when testing "complete_cv_example.py"
|
|
||||||
special_strings (`list`, *optional*):
|
|
||||||
A list of strings to potentially remove before checking no differences are left. These should be
|
|
||||||
diffs that are file specific, such as different logging variations between files.
|
|
||||||
"""
|
|
||||||
self.maxDiff = None
|
|
||||||
for item in os.listdir(self.by_feature_path):
|
|
||||||
if item not in EXCLUDE_EXAMPLES:
|
|
||||||
item_path = self.by_feature_path / item
|
|
||||||
if item_path.is_file() and item_path.suffix == ".py":
|
|
||||||
with self.subTest(
|
|
||||||
tested_script=complete_file_name,
|
|
||||||
feature_script=item,
|
|
||||||
tested_section="main()"
|
|
||||||
if parser_only
|
|
||||||
else "training_function()",
|
|
||||||
):
|
|
||||||
diff = compare_against_test(
|
|
||||||
self.examples_path / complete_file_name,
|
|
||||||
item_path,
|
|
||||||
parser_only,
|
|
||||||
secondary_filename,
|
|
||||||
)
|
|
||||||
diff = "\n".join(diff)
|
|
||||||
if special_strings is not None:
|
|
||||||
for string in special_strings:
|
|
||||||
diff = diff.replace(string, "")
|
|
||||||
assert diff == ""
|
|
||||||
|
|
||||||
def test_nlp_examples(self):
|
|
||||||
self.one_complete_example("complete_nlp_example.py", True)
|
|
||||||
self.one_complete_example("complete_nlp_example.py", False)
|
|
||||||
|
|
||||||
def test_cv_examples(self):
|
|
||||||
cv_path = (self.examples_path / "cv_example.py").resolve()
|
|
||||||
special_strings = [
|
|
||||||
" " * 16 + "{\n\n",
|
|
||||||
" " * 20 + '"accuracy": eval_metric["accuracy"],\n\n',
|
|
||||||
" " * 20 + '"f1": eval_metric["f1"],\n\n',
|
|
||||||
" " * 20 + '"train_loss": total_loss.item() / len(train_dataloader),\n\n',
|
|
||||||
" " * 20 + '"epoch": epoch,\n\n',
|
|
||||||
" " * 16 + "},\n\n",
|
|
||||||
" " * 16 + "step=epoch,\n",
|
|
||||||
" " * 12,
|
|
||||||
" " * 8 + "for step, batch in enumerate(active_dataloader):\n",
|
|
||||||
]
|
|
||||||
self.one_complete_example(
|
|
||||||
"complete_cv_example.py", True, cv_path, special_strings
|
|
||||||
)
|
|
||||||
self.one_complete_example(
|
|
||||||
"complete_cv_example.py", False, cv_path, special_strings
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})
|
class ExamplesTests(TempDirTestCase):
|
||||||
@require_huggingface_suite
|
|
||||||
class FeatureExamplesTests(TempDirTestCase):
|
|
||||||
clear_on_setup = False
|
clear_on_setup = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super().setUpClass()
|
super().setUpClass()
|
||||||
cls._tmpdir = tempfile.mkdtemp()
|
cls._tmpdir = tempfile.mkdtemp()
|
||||||
cls.config_file = Path(cls._tmpdir) / "default_config.yml"
|
cls.launch_args = ["python3"]
|
||||||
|
|
||||||
write_basic_config(save_location=cls.config_file)
|
|
||||||
cls.launch_args = get_launch_command(config_file=cls.config_file)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
super().tearDownClass()
|
super().tearDownClass()
|
||||||
shutil.rmtree(cls._tmpdir)
|
shutil.rmtree(cls._tmpdir)
|
||||||
|
|
||||||
def test_checkpointing_by_epoch(self):
|
|
||||||
testargs = f"""
|
def test_oneshot(self):
|
||||||
examples/by_feature/checkpointing.py
|
testargs = ["examples/oneshot.py"]
|
||||||
--checkpointing_steps epoch
|
|
||||||
--output_dir {self.tmpdir}
|
|
||||||
""".split()
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
assert (self.tmpdir / "epoch_0").exists()
|
|
||||||
|
|
||||||
def test_checkpointing_by_steps(self):
|
|
||||||
testargs = f"""
|
|
||||||
examples/by_feature/checkpointing.py
|
|
||||||
--checkpointing_steps 1
|
|
||||||
--output_dir {self.tmpdir}
|
|
||||||
""".split()
|
|
||||||
_ = run_command(self.launch_args + testargs)
|
|
||||||
assert (self.tmpdir / "step_2").exists()
|
|
||||||
|
|
||||||
def test_load_states_by_epoch(self):
|
|
||||||
testargs = f"""
|
|
||||||
examples/by_feature/checkpointing.py
|
|
||||||
--resume_from_checkpoint {self.tmpdir / "epoch_0"}
|
|
||||||
""".split()
|
|
||||||
output = run_command(self.launch_args + testargs, return_stdout=True)
|
|
||||||
assert "epoch 0:" not in output
|
|
||||||
assert "epoch 1:" in output
|
|
||||||
|
|
||||||
def test_load_states_by_steps(self):
|
|
||||||
testargs = f"""
|
|
||||||
examples/by_feature/checkpointing.py
|
|
||||||
--resume_from_checkpoint {self.tmpdir / "step_2"}
|
|
||||||
""".split()
|
|
||||||
output = run_command(self.launch_args + testargs, return_stdout=True)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
num_processes = torch.cuda.device_count()
|
|
||||||
else:
|
|
||||||
num_processes = 1
|
|
||||||
if num_processes > 1:
|
|
||||||
assert "epoch 0:" not in output
|
|
||||||
assert "epoch 1:" in output
|
|
||||||
else:
|
|
||||||
assert "epoch 0:" in output
|
|
||||||
assert "epoch 1:" in output
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_cross_validation(self):
|
|
||||||
testargs = """
|
|
||||||
examples/by_feature/cross_validation.py
|
|
||||||
--num_folds 2
|
|
||||||
""".split()
|
|
||||||
with mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "0"}):
|
|
||||||
output = run_command(self.launch_args + testargs, return_stdout=True)
|
|
||||||
results = re.findall("({.+})", output)
|
|
||||||
results = [r for r in results if "accuracy" in r][-1]
|
|
||||||
results = ast.literal_eval(results)
|
|
||||||
assert results["accuracy"] >= 0.75
|
|
||||||
|
|
||||||
def test_multi_process_metrics(self):
|
|
||||||
testargs = ["examples/by_feature/multi_process_metrics.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
@require_schedulefree
|
|
||||||
def test_schedulefree(self):
|
|
||||||
testargs = ["examples/by_feature/schedule_free.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
@require_trackers
|
|
||||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
|
||||||
def test_tracking(self):
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
testargs = f"""
|
|
||||||
examples/by_feature/tracking.py
|
|
||||||
--with_tracking
|
|
||||||
--project_dir {tmpdir}
|
|
||||||
""".split()
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
assert os.path.exists(os.path.join(tmpdir, "tracking"))
|
|
||||||
|
|
||||||
def test_gradient_accumulation(self):
|
|
||||||
testargs = ["examples/by_feature/gradient_accumulation.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
def test_local_sgd(self):
|
|
||||||
testargs = ["examples/by_feature/local_sgd.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
def test_early_stopping(self):
|
|
||||||
testargs = ["examples/by_feature/early_stopping.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
def test_profiler(self):
|
|
||||||
testargs = ["examples/by_feature/profiler.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
@require_multi_device
|
|
||||||
def test_ddp_comm_hook(self):
|
|
||||||
testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
@skip(
|
|
||||||
reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
|
|
||||||
)
|
|
||||||
@require_multi_device
|
|
||||||
def test_distributed_inference_examples_stable_diffusion(self):
|
|
||||||
testargs = ["examples/inference/distributed/stable_diffusion.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
@require_multi_device
|
|
||||||
def test_distributed_inference_examples_phi2(self):
|
|
||||||
testargs = ["examples/inference/distributed/phi2.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
@require_non_xpu
|
|
||||||
@require_pippy
|
|
||||||
@require_multi_gpu
|
|
||||||
def test_pippy_examples_bert(self):
|
|
||||||
testargs = ["examples/inference/pippy/bert.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
|
||||||
|
|
||||||
@require_non_xpu
|
|
||||||
@require_pippy
|
|
||||||
@require_multi_gpu
|
|
||||||
def test_pippy_examples_gpt2(self):
|
|
||||||
testargs = ["examples/inference/pippy/gpt2.py"]
|
|
||||||
run_command(self.launch_args + testargs)
|
run_command(self.launch_args + testargs)
|
||||||
|
|
|
@ -48,7 +48,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
inputs_text = {"answer": "Text input"}
|
inputs_text = {"answer": "Text input"}
|
||||||
inputs_image = {
|
inputs_image = {
|
||||||
"answer": Image.open(
|
"answer": Image.open(
|
||||||
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
Path(get_tests_dir("fixtures")) / "000000039769.png"
|
||||||
).resize((512, 512))
|
).resize((512, 512))
|
||||||
}
|
}
|
||||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||||
|
|
|
@ -50,7 +50,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||||
inputs[input_name] = "Text input"
|
inputs[input_name] = "Text input"
|
||||||
elif input_type == "image":
|
elif input_type == "image":
|
||||||
inputs[input_name] = Image.open(
|
inputs[input_name] = Image.open(
|
||||||
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
Path(get_tests_dir("fixtures")) / "000000039769.png"
|
||||||
).resize((512, 512))
|
).resize((512, 512))
|
||||||
elif input_type == "audio":
|
elif input_type == "audio":
|
||||||
inputs[input_name] = np.ones(3000)
|
inputs[input_name] = np.ones(3000)
|
||||||
|
|
|
@ -1,69 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from transformers import load_tool
|
|
||||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin, output_type
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
|
||||||
def setUp(self):
|
|
||||||
self.tool = load_tool("translation")
|
|
||||||
self.tool.setup()
|
|
||||||
self.remote_tool = load_tool("translation", remote=True)
|
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
|
||||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
|
||||||
self.assertEqual(result, "- Hé, comment ça va?")
|
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
|
||||||
result = self.tool(
|
|
||||||
text="Hey, what's up?", src_lang="English", tgt_lang="French"
|
|
||||||
)
|
|
||||||
self.assertEqual(result, "- Hé, comment ça va?")
|
|
||||||
|
|
||||||
def test_call(self):
|
|
||||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
|
||||||
output = self.tool(*inputs)
|
|
||||||
|
|
||||||
self.assertEqual(output_type(output), self.tool.output_type)
|
|
||||||
|
|
||||||
def test_agent_type_output(self):
|
|
||||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
|
||||||
output = self.tool(*inputs)
|
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
|
||||||
self.assertTrue(isinstance(output, output_type))
|
|
||||||
|
|
||||||
def test_agent_types_inputs(self):
|
|
||||||
example_inputs = {
|
|
||||||
"text": "Hey, what's up?",
|
|
||||||
"src_lang": "English",
|
|
||||||
"tgt_lang": "Spanish",
|
|
||||||
}
|
|
||||||
|
|
||||||
_inputs = []
|
|
||||||
for input_name in example_inputs.keys():
|
|
||||||
example_input = example_inputs[input_name]
|
|
||||||
input_description = self.tool.inputs[input_name]
|
|
||||||
input_type = input_description["type"]
|
|
||||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](example_input))
|
|
||||||
|
|
||||||
# Should not raise an error
|
|
||||||
output = self.tool(**example_inputs)
|
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
|
||||||
self.assertTrue(isinstance(output, output_type))
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
def str_to_bool(value) -> int:
|
||||||
|
"""
|
||||||
|
Converts a string representation of truth to `True` (1) or `False` (0).
|
||||||
|
|
||||||
|
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
|
||||||
|
"""
|
||||||
|
value = value.lower()
|
||||||
|
if value in ("y", "yes", "t", "true", "on", "1"):
|
||||||
|
return 1
|
||||||
|
elif value in ("n", "no", "f", "false", "off", "0"):
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"invalid truth value {value}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_int_from_env(env_keys, default):
|
||||||
|
"""Returns the first positive env value found in the `env_keys` list or the default."""
|
||||||
|
for e in env_keys:
|
||||||
|
val = int(os.environ.get(e, -1))
|
||||||
|
if val >= 0:
|
||||||
|
return val
|
||||||
|
return default
|
||||||
|
|
||||||
|
def parse_flag_from_env(key, default=False):
|
||||||
|
"""Returns truthy value for `key` from the env if available else the default."""
|
||||||
|
value = os.environ.get(key, str(default))
|
||||||
|
return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int...
|
||||||
|
|
||||||
|
|
||||||
|
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||||||
|
|
||||||
|
|
||||||
|
def skip(test_case):
|
||||||
|
"Decorator that skips a test unconditionally"
|
||||||
|
return unittest.skip("Test was skipped")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def slow(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
|
||||||
|
truthy value to run them.
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||||
|
|
||||||
|
def get_launch_command(**kwargs) -> list:
|
||||||
|
"""
|
||||||
|
Wraps around `kwargs` to help simplify launching from `subprocess`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2']
|
||||||
|
get_launch_command(num_processes=2, device_count=2)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
command = ["accelerate", "launch"]
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, bool) and v:
|
||||||
|
command.append(f"--{k}")
|
||||||
|
elif v is not None:
|
||||||
|
command.append(f"--{k}={v}")
|
||||||
|
return command
|
||||||
|
|
||||||
|
|
||||||
|
class TempDirTestCase(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
|
||||||
|
data at the start of a test, and then destroyes it at the end of the TestCase.
|
||||||
|
|
||||||
|
Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases
|
||||||
|
|
||||||
|
The temporary directory location will be stored in `self.tmpdir`
|
||||||
|
"""
|
||||||
|
|
||||||
|
clear_on_setup = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
"Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`"
|
||||||
|
cls.tmpdir = Path(tempfile.mkdtemp())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
"Remove `cls.tmpdir` after test suite has finished"
|
||||||
|
if os.path.exists(cls.tmpdir):
|
||||||
|
shutil.rmtree(cls.tmpdir)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
|
||||||
|
if self.clear_on_setup:
|
||||||
|
for path in self.tmpdir.glob("**/*"):
|
||||||
|
if path.is_file():
|
||||||
|
path.unlink()
|
||||||
|
elif path.is_dir():
|
||||||
|
shutil.rmtree(path)
|
Loading…
Reference in New Issue