Add linter rules + apply make style (#255)
* Add linter rules + apply make style
This commit is contained in:
parent
5aa0f2b53d
commit
6e1373a324
|
@ -181,6 +181,7 @@
|
||||||
"import datasets\n",
|
"import datasets\n",
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
|
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
|
||||||
"pd.DataFrame(eval_ds)"
|
"pd.DataFrame(eval_ds)"
|
||||||
]
|
]
|
||||||
|
@ -199,26 +200,28 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import time\n",
|
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import re\n",
|
"import re\n",
|
||||||
"import string\n",
|
"import string\n",
|
||||||
|
"import time\n",
|
||||||
"import warnings\n",
|
"import warnings\n",
|
||||||
"from tqdm import tqdm\n",
|
|
||||||
"from typing import List\n",
|
"from typing import List\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"from dotenv import load_dotenv\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"\n",
|
||||||
"from smolagents import (\n",
|
"from smolagents import (\n",
|
||||||
" GoogleSearchTool,\n",
|
|
||||||
" CodeAgent,\n",
|
|
||||||
" ToolCallingAgent,\n",
|
|
||||||
" HfApiModel,\n",
|
|
||||||
" AgentError,\n",
|
" AgentError,\n",
|
||||||
" VisitWebpageTool,\n",
|
" CodeAgent,\n",
|
||||||
|
" GoogleSearchTool,\n",
|
||||||
|
" HfApiModel,\n",
|
||||||
" PythonInterpreterTool,\n",
|
" PythonInterpreterTool,\n",
|
||||||
|
" ToolCallingAgent,\n",
|
||||||
|
" VisitWebpageTool,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"from smolagents.agents import ActionStep\n",
|
"from smolagents.agents import ActionStep\n",
|
||||||
"from dotenv import load_dotenv\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"load_dotenv()\n",
|
"load_dotenv()\n",
|
||||||
"os.makedirs(\"output\", exist_ok=True)\n",
|
"os.makedirs(\"output\", exist_ok=True)\n",
|
||||||
|
@ -231,9 +234,7 @@
|
||||||
" return str(obj)\n",
|
" return str(obj)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def answer_questions(\n",
|
"def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n",
|
||||||
" eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False\n",
|
|
||||||
"):\n",
|
|
||||||
" answered_questions = []\n",
|
" answered_questions = []\n",
|
||||||
" if os.path.exists(file_name):\n",
|
" if os.path.exists(file_name):\n",
|
||||||
" with open(file_name, \"r\") as f:\n",
|
" with open(file_name, \"r\") as f:\n",
|
||||||
|
@ -365,23 +366,18 @@
|
||||||
" ma_elems = split_string(model_answer)\n",
|
" ma_elems = split_string(model_answer)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if len(gt_elems) != len(ma_elems): # check length is the same\n",
|
" if len(gt_elems) != len(ma_elems): # check length is the same\n",
|
||||||
" warnings.warn(\n",
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
" \"Answer lists have different lengths, returning False.\", UserWarning\n",
|
|
||||||
" )\n",
|
|
||||||
" return False\n",
|
" return False\n",
|
||||||
"\n",
|
"\n",
|
||||||
" comparisons = []\n",
|
" comparisons = []\n",
|
||||||
" for ma_elem, gt_elem in zip(\n",
|
" for ma_elem, gt_elem in zip(ma_elems, gt_elems): # compare each element as float or str\n",
|
||||||
" ma_elems, gt_elems\n",
|
|
||||||
" ): # compare each element as float or str\n",
|
|
||||||
" if is_float(gt_elem):\n",
|
" if is_float(gt_elem):\n",
|
||||||
" normalized_ma_elem = normalize_number_str(ma_elem)\n",
|
" normalized_ma_elem = normalize_number_str(ma_elem)\n",
|
||||||
" comparisons.append(normalized_ma_elem == float(gt_elem))\n",
|
" comparisons.append(normalized_ma_elem == float(gt_elem))\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" # we do not remove punct since comparisons can include punct\n",
|
" # we do not remove punct since comparisons can include punct\n",
|
||||||
" comparisons.append(\n",
|
" comparisons.append(\n",
|
||||||
" normalize_str(ma_elem, remove_punct=False)\n",
|
" normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)\n",
|
||||||
" == normalize_str(gt_elem, remove_punct=False)\n",
|
|
||||||
" )\n",
|
" )\n",
|
||||||
" return all(comparisons)\n",
|
" return all(comparisons)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -441,9 +437,7 @@
|
||||||
" action_type = \"vanilla\"\n",
|
" action_type = \"vanilla\"\n",
|
||||||
" llm = HfApiModel(model_id)\n",
|
" llm = HfApiModel(model_id)\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
||||||
" answer_questions(\n",
|
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
|
||||||
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
|
|
||||||
" )"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -461,6 +455,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from smolagents import LiteLLMModel\n",
|
"from smolagents import LiteLLMModel\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
|
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for model_id in litellm_model_ids:\n",
|
"for model_id in litellm_model_ids:\n",
|
||||||
|
@ -492,9 +487,7 @@
|
||||||
" action_type = \"vanilla\"\n",
|
" action_type = \"vanilla\"\n",
|
||||||
" llm = LiteLLMModel(model_id)\n",
|
" llm = LiteLLMModel(model_id)\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
||||||
" answer_questions(\n",
|
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
|
||||||
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
|
|
||||||
" )"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -556,9 +549,11 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import pandas as pd\n",
|
|
||||||
"import glob\n",
|
"import glob\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"res = []\n",
|
"res = []\n",
|
||||||
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
|
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
|
||||||
" data = []\n",
|
" data = []\n",
|
||||||
|
@ -595,11 +590,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
|
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"result_df = (\n",
|
"result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
|
||||||
" (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100)\n",
|
|
||||||
" .round(1)\n",
|
|
||||||
" .reset_index()\n",
|
|
||||||
")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -895,6 +886,7 @@
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"from matplotlib.legend_handler import HandlerTuple # Added import\n",
|
"from matplotlib.legend_handler import HandlerTuple # Added import\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"# Assuming pivot_df is your original dataframe\n",
|
"# Assuming pivot_df is your original dataframe\n",
|
||||||
"models = pivot_df[\"model_id\"].unique()\n",
|
"models = pivot_df[\"model_id\"].unique()\n",
|
||||||
"sources = pivot_df[\"source\"].unique()\n",
|
"sources = pivot_df[\"source\"].unique()\n",
|
||||||
|
@ -961,14 +953,10 @@
|
||||||
"handles, labels = ax.get_legend_handles_labels()\n",
|
"handles, labels = ax.get_legend_handles_labels()\n",
|
||||||
"unique_sources = sources\n",
|
"unique_sources = sources\n",
|
||||||
"legend_elements = [\n",
|
"legend_elements = [\n",
|
||||||
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\"))\n",
|
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\")) for i in range(len(unique_sources))\n",
|
||||||
" for i in range(len(unique_sources))\n",
|
|
||||||
"]\n",
|
"]\n",
|
||||||
"custom_legend = ax.legend(\n",
|
"custom_legend = ax.legend(\n",
|
||||||
" [\n",
|
" [(agent_handle, vanilla_handle) for agent_handle, vanilla_handle, _ in legend_elements],\n",
|
||||||
" (agent_handle, vanilla_handle)\n",
|
|
||||||
" for agent_handle, vanilla_handle, _ in legend_elements\n",
|
|
||||||
" ],\n",
|
|
||||||
" [label for _, _, label in legend_elements],\n",
|
" [label for _, _, label in legend_elements],\n",
|
||||||
" handler_map={tuple: HandlerTuple(ndivide=None)},\n",
|
" handler_map={tuple: HandlerTuple(ndivide=None)},\n",
|
||||||
" bbox_to_anchor=(1.05, 1),\n",
|
" bbox_to_anchor=(1.05, 1),\n",
|
||||||
|
@ -1006,9 +994,7 @@
|
||||||
" # Start the matrix environment with 4 columns\n",
|
" # Start the matrix environment with 4 columns\n",
|
||||||
" # l for left-aligned model and task, c for centered numbers\n",
|
" # l for left-aligned model and task, c for centered numbers\n",
|
||||||
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
|
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
|
||||||
" mathjax_table += (\n",
|
" mathjax_table += \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
|
||||||
" \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
|
|
||||||
" )\n",
|
|
||||||
" mathjax_table += \"\\\\hline\\n\"\n",
|
" mathjax_table += \"\\\\hline\\n\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Sort the DataFrame by model_id and source\n",
|
" # Sort the DataFrame by model_id and source\n",
|
||||||
|
@ -1033,9 +1019,7 @@
|
||||||
" model_display = \"\\\\;\"\n",
|
" model_display = \"\\\\;\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Add the data row\n",
|
" # Add the data row\n",
|
||||||
" mathjax_table += (\n",
|
" mathjax_table += f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
|
||||||
" f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
|
|
||||||
" )\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" current_model = model\n",
|
" current_model = model\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
from smolagents import Tool, CodeAgent, HfApiModel
|
|
||||||
from smolagents.default_tools import VisitWebpageTool
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from smolagents import CodeAgent, HfApiModel, Tool
|
||||||
|
from smolagents.default_tools import VisitWebpageTool
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,10 +18,11 @@ class GetCatImageTool(Tool):
|
||||||
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
|
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
from PIL import Image
|
|
||||||
import requests
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
response = requests.get(self.url)
|
response = requests.get(self.url)
|
||||||
|
|
||||||
return Image.open(BytesIO(response.content))
|
return Image.open(BytesIO(response.content))
|
||||||
|
@ -46,4 +49,5 @@ agent.run(
|
||||||
# Try the agent in a Gradio UI
|
# Try the agent in a Gradio UI
|
||||||
from smolagents import GradioUI
|
from smolagents import GradioUI
|
||||||
|
|
||||||
|
|
||||||
GradioUI(agent).launch()
|
GradioUI(agent).launch()
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from smolagents import CodeAgent, HfApiModel, GradioUI
|
from smolagents import CodeAgent, GradioUI, HfApiModel
|
||||||
|
|
||||||
|
|
||||||
agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)
|
agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,22 @@
|
||||||
|
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
||||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||||
from opentelemetry.sdk.trace import TracerProvider
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||||
|
|
||||||
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
|
||||||
|
|
||||||
from smolagents import (
|
from smolagents import (
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
DuckDuckGoSearchTool,
|
DuckDuckGoSearchTool,
|
||||||
VisitWebpageTool,
|
HfApiModel,
|
||||||
ManagedAgent,
|
ManagedAgent,
|
||||||
ToolCallingAgent,
|
ToolCallingAgent,
|
||||||
HfApiModel,
|
VisitWebpageTool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Let's setup the instrumentation first
|
# Let's setup the instrumentation first
|
||||||
|
|
||||||
trace_provider = TracerProvider()
|
trace_provider = TracerProvider()
|
||||||
trace_provider.add_span_processor(
|
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")))
|
||||||
SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))
|
|
||||||
)
|
|
||||||
|
|
||||||
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
|
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
|
||||||
|
|
||||||
|
@ -39,6 +37,4 @@ manager_agent = CodeAgent(
|
||||||
model=model,
|
model=model,
|
||||||
managed_agents=[managed_agent],
|
managed_agents=[managed_agent],
|
||||||
)
|
)
|
||||||
manager_agent.run(
|
manager_agent.run("If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?")
|
||||||
"If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?"
|
|
||||||
)
|
|
||||||
|
|
|
@ -8,13 +8,10 @@ from langchain_community.retrievers import BM25Retriever
|
||||||
|
|
||||||
|
|
||||||
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
|
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
|
||||||
knowledge_base = knowledge_base.filter(
|
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
|
||||||
lambda row: row["source"].startswith("huggingface/transformers")
|
|
||||||
)
|
|
||||||
|
|
||||||
source_docs = [
|
source_docs = [
|
||||||
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
|
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
|
||||||
for doc in knowledge_base
|
|
||||||
]
|
]
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
@ -51,14 +48,12 @@ class RetrieverTool(Tool):
|
||||||
query,
|
query,
|
||||||
)
|
)
|
||||||
return "\nRetrieved documents:\n" + "".join(
|
return "\nRetrieved documents:\n" + "".join(
|
||||||
[
|
[f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
|
||||||
f"\n\n===== Document {str(i)} =====\n" + doc.page_content
|
|
||||||
for i, doc in enumerate(docs)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from smolagents import HfApiModel, CodeAgent
|
from smolagents import CodeAgent, HfApiModel
|
||||||
|
|
||||||
|
|
||||||
retriever_tool = RetrieverTool(docs_processed)
|
retriever_tool = RetrieverTool(docs_processed)
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
|
@ -68,9 +63,7 @@ agent = CodeAgent(
|
||||||
verbosity_level=2,
|
verbosity_level=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_output = agent.run(
|
agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
|
||||||
"For a transformers model training, which is slower, the forward or the backward pass?"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Final output:")
|
print("Final output:")
|
||||||
print(agent_output)
|
print(agent_output)
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
create_engine,
|
|
||||||
MetaData,
|
|
||||||
Table,
|
|
||||||
Column,
|
Column,
|
||||||
String,
|
|
||||||
Integer,
|
|
||||||
Float,
|
Float,
|
||||||
|
Integer,
|
||||||
|
MetaData,
|
||||||
|
String,
|
||||||
|
Table,
|
||||||
|
create_engine,
|
||||||
insert,
|
insert,
|
||||||
inspect,
|
inspect,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
engine = create_engine("sqlite:///:memory:")
|
engine = create_engine("sqlite:///:memory:")
|
||||||
metadata_obj = MetaData()
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
|
@ -40,9 +41,7 @@ for row in rows:
|
||||||
inspector = inspect(engine)
|
inspector = inspect(engine)
|
||||||
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
|
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
|
||||||
|
|
||||||
table_description = "Columns:\n" + "\n".join(
|
table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
|
||||||
[f" - {name}: {col_type}" for name, col_type in columns_info]
|
|
||||||
)
|
|
||||||
print(table_description)
|
print(table_description)
|
||||||
|
|
||||||
from smolagents import tool
|
from smolagents import tool
|
||||||
|
@ -72,6 +71,7 @@ def sql_engine(query: str) -> str:
|
||||||
|
|
||||||
from smolagents import CodeAgent, HfApiModel
|
from smolagents import CodeAgent, HfApiModel
|
||||||
|
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[sql_engine],
|
tools=[sql_engine],
|
||||||
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
|
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
from smolagents.agents import ToolCallingAgent
|
|
||||||
from smolagents import tool, LiteLLMModel
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from smolagents import LiteLLMModel, tool
|
||||||
|
from smolagents.agents import ToolCallingAgent
|
||||||
|
|
||||||
|
|
||||||
# Choose which LLM engine to use!
|
# Choose which LLM engine to use!
|
||||||
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
|
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
|
||||||
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")
|
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")
|
||||||
|
|
|
@ -13,8 +13,10 @@ Usage:
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from mcp import StdioServerParameters
|
from mcp import StdioServerParameters
|
||||||
|
|
||||||
from smolagents import CodeAgent, HfApiModel, ToolCollection
|
from smolagents import CodeAgent, HfApiModel, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
mcp_server_params = StdioServerParameters(
|
mcp_server_params = StdioServerParameters(
|
||||||
command="uvx",
|
command="uvx",
|
||||||
args=["--quiet", "pubmedmcp@0.1.3"],
|
args=["--quiet", "pubmedmcp@0.1.3"],
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
from smolagents.agents import ToolCallingAgent
|
|
||||||
from smolagents import tool, LiteLLMModel
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from smolagents import LiteLLMModel, tool
|
||||||
|
from smolagents.agents import ToolCallingAgent
|
||||||
|
|
||||||
|
|
||||||
model = LiteLLMModel(
|
model = LiteLLMModel(
|
||||||
model_id="ollama_chat/llama3.2",
|
model_id="ollama_chat/llama3.2",
|
||||||
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
|
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
|
||||||
|
|
|
@ -60,9 +60,18 @@ dev = [
|
||||||
addopts = "-sv --durations=0"
|
addopts = "-sv --durations=0"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
lint.ignore = ["F403"]
|
line-length = 119
|
||||||
|
lint.ignore = [
|
||||||
|
"F403", # undefined-local-with-import-star
|
||||||
|
"E501", # line-too-long
|
||||||
|
]
|
||||||
|
lint.select = ["E", "F", "I", "W"]
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"examples/*" = [
|
"examples/*" = [
|
||||||
"E402", # module-import-not-at-top-of-file
|
"E402", # module-import-not-at-top-of-file
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["smolagents"]
|
||||||
|
lines-after-imports = 2
|
||||||
|
|
|
@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
||||||
from transformers.utils import _LazyModule
|
from transformers.utils import _LazyModule
|
||||||
from transformers.utils.import_utils import define_import_structure
|
from transformers.utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import *
|
from .agents import *
|
||||||
from .default_tools import *
|
from .default_tools import *
|
||||||
|
|
|
@ -16,18 +16,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import IntEnum
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from enum import IntEnum
|
|
||||||
from rich import box
|
from rich import box
|
||||||
from rich.console import Group
|
from rich.console import Console, Group
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
from rich.syntax import Syntax
|
from rich.syntax import Syntax
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
from rich.console import Console
|
|
||||||
|
|
||||||
from .default_tools import FinalAnswerTool, TOOL_MAPPING
|
from .default_tools import TOOL_MAPPING, FinalAnswerTool
|
||||||
from .e2b_executor import E2BExecutor
|
from .e2b_executor import E2BExecutor
|
||||||
from .local_python_executor import (
|
from .local_python_executor import (
|
||||||
BASE_BUILTIN_MODULES,
|
BASE_BUILTIN_MODULES,
|
||||||
|
@ -112,20 +111,11 @@ class SystemPromptStep(AgentStepLog):
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
|
|
||||||
|
|
||||||
def get_tool_descriptions(
|
def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str:
|
||||||
tools: Dict[str, Tool], tool_description_template: str
|
return "\n".join([get_tool_description_with_args(tool, tool_description_template) for tool in tools.values()])
|
||||||
) -> str:
|
|
||||||
return "\n".join(
|
|
||||||
[
|
|
||||||
get_tool_description_with_args(tool, tool_description_template)
|
|
||||||
for tool in tools.values()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_tools(
|
def format_prompt_with_tools(tools: Dict[str, Tool], prompt_template: str, tool_description_template: str) -> str:
|
||||||
tools: Dict[str, Tool], prompt_template: str, tool_description_template: str
|
|
||||||
) -> str:
|
|
||||||
tool_descriptions = get_tool_descriptions(tools, tool_description_template)
|
tool_descriptions = get_tool_descriptions(tools, tool_description_template)
|
||||||
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
||||||
if "{{tool_names}}" in prompt:
|
if "{{tool_names}}" in prompt:
|
||||||
|
@ -159,9 +149,7 @@ def format_prompt_with_managed_agents_descriptions(
|
||||||
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
|
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
|
||||||
)
|
)
|
||||||
if len(managed_agents.keys()) > 0:
|
if len(managed_agents.keys()) > 0:
|
||||||
return prompt_template.replace(
|
return prompt_template.replace(agent_descriptions_placeholder, show_agents_descriptions(managed_agents))
|
||||||
agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return prompt_template.replace(agent_descriptions_placeholder, "")
|
return prompt_template.replace(agent_descriptions_placeholder, "")
|
||||||
|
|
||||||
|
@ -214,9 +202,7 @@ class MultiStepAgent:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.system_prompt_template = system_prompt
|
self.system_prompt_template = system_prompt
|
||||||
self.tool_description_template = (
|
self.tool_description_template = (
|
||||||
tool_description_template
|
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||||
if tool_description_template
|
|
||||||
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
|
||||||
)
|
)
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
|
@ -231,10 +217,7 @@ class MultiStepAgent:
|
||||||
self.tools = {tool.name: tool for tool in tools}
|
self.tools = {tool.name: tool for tool in tools}
|
||||||
if add_base_tools:
|
if add_base_tools:
|
||||||
for tool_name, tool_class in TOOL_MAPPING.items():
|
for tool_name, tool_class in TOOL_MAPPING.items():
|
||||||
if (
|
if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
|
||||||
tool_name != "python_interpreter"
|
|
||||||
or self.__class__.__name__ == "ToolCallingAgent"
|
|
||||||
):
|
|
||||||
self.tools[tool_name] = tool_class()
|
self.tools[tool_name] = tool_class()
|
||||||
self.tools["final_answer"] = FinalAnswerTool()
|
self.tools["final_answer"] = FinalAnswerTool()
|
||||||
|
|
||||||
|
@ -253,15 +236,11 @@ class MultiStepAgent:
|
||||||
self.system_prompt_template,
|
self.system_prompt_template,
|
||||||
self.tool_description_template,
|
self.tool_description_template,
|
||||||
)
|
)
|
||||||
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
||||||
self.system_prompt, self.managed_agents
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.system_prompt
|
return self.system_prompt
|
||||||
|
|
||||||
def write_inner_memory_from_logs(
|
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
||||||
self, summary_mode: Optional[bool] = False
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
"""
|
"""
|
||||||
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
||||||
that can be used as input to the LLM.
|
that can be used as input to the LLM.
|
||||||
|
@ -355,10 +334,7 @@ class MultiStepAgent:
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
def get_succinct_logs(self):
|
def get_succinct_logs(self):
|
||||||
return [
|
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
|
||||||
{key: value for key, value in log.items() if key != "agent_memory"}
|
|
||||||
for log in self.logs
|
|
||||||
]
|
|
||||||
|
|
||||||
def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
|
def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
|
@ -402,9 +378,7 @@ class MultiStepAgent:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error in generating final LLM output:\n{e}"
|
return f"Error in generating final LLM output:\n{e}"
|
||||||
|
|
||||||
def execute_tool_call(
|
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
|
||||||
self, tool_name: str, arguments: Union[Dict[str, str], str]
|
|
||||||
) -> Any:
|
|
||||||
"""
|
"""
|
||||||
Execute tool with the provided input and returns the result.
|
Execute tool with the provided input and returns the result.
|
||||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||||
|
@ -423,9 +397,7 @@ class MultiStepAgent:
|
||||||
if tool_name in self.managed_agents:
|
if tool_name in self.managed_agents:
|
||||||
observation = available_tools[tool_name].__call__(arguments)
|
observation = available_tools[tool_name].__call__(arguments)
|
||||||
else:
|
else:
|
||||||
observation = available_tools[tool_name].__call__(
|
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
|
||||||
arguments, sanitize_inputs_outputs=True
|
|
||||||
)
|
|
||||||
elif isinstance(arguments, dict):
|
elif isinstance(arguments, dict):
|
||||||
for key, value in arguments.items():
|
for key, value in arguments.items():
|
||||||
if isinstance(value, str) and value in self.state:
|
if isinstance(value, str) and value in self.state:
|
||||||
|
@ -433,18 +405,14 @@ class MultiStepAgent:
|
||||||
if tool_name in self.managed_agents:
|
if tool_name in self.managed_agents:
|
||||||
observation = available_tools[tool_name].__call__(**arguments)
|
observation = available_tools[tool_name].__call__(**arguments)
|
||||||
else:
|
else:
|
||||||
observation = available_tools[tool_name].__call__(
|
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
||||||
**arguments, sanitize_inputs_outputs=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
return observation
|
return observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if tool_name in self.tools:
|
if tool_name in self.tools:
|
||||||
tool_description = get_tool_description_with_args(
|
tool_description = get_tool_description_with_args(available_tools[tool_name])
|
||||||
available_tools[tool_name]
|
|
||||||
)
|
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||||
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
||||||
|
@ -544,10 +512,7 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||||
try:
|
try:
|
||||||
if (
|
if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
|
||||||
self.planning_interval is not None
|
|
||||||
and self.step_number % self.planning_interval == 0
|
|
||||||
):
|
|
||||||
self.planning_step(
|
self.planning_step(
|
||||||
task,
|
task,
|
||||||
is_first_step=(self.step_number == 0),
|
is_first_step=(self.step_number == 0),
|
||||||
|
@ -600,10 +565,7 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||||
try:
|
try:
|
||||||
if (
|
if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
|
||||||
self.planning_interval is not None
|
|
||||||
and self.step_number % self.planning_interval == 0
|
|
||||||
):
|
|
||||||
self.planning_step(
|
self.planning_step(
|
||||||
task,
|
task,
|
||||||
is_first_step=(self.step_number == 0),
|
is_first_step=(self.step_number == 0),
|
||||||
|
@ -668,9 +630,7 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
Now begin!""",
|
Now begin!""",
|
||||||
}
|
}
|
||||||
|
|
||||||
answer_facts = self.model(
|
answer_facts = self.model([message_prompt_facts, message_prompt_task]).content
|
||||||
[message_prompt_facts, message_prompt_task]
|
|
||||||
).content
|
|
||||||
|
|
||||||
message_system_prompt_plan = {
|
message_system_prompt_plan = {
|
||||||
"role": MessageRole.SYSTEM,
|
"role": MessageRole.SYSTEM,
|
||||||
|
@ -680,12 +640,8 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_PLAN.format(
|
"content": USER_PROMPT_PLAN.format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=get_tool_descriptions(
|
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
|
||||||
self.tools, self.tool_description_template
|
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
|
||||||
),
|
|
||||||
managed_agents_descriptions=(
|
|
||||||
show_agents_descriptions(self.managed_agents)
|
|
||||||
),
|
|
||||||
answer_facts=answer_facts,
|
answer_facts=answer_facts,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
@ -702,9 +658,7 @@ Now begin!""",
|
||||||
```
|
```
|
||||||
{answer_facts}
|
{answer_facts}
|
||||||
```""".strip()
|
```""".strip()
|
||||||
self.logs.append(
|
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
|
||||||
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
|
|
||||||
)
|
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
Rule("[bold]Initial plan", style="orange"),
|
Rule("[bold]Initial plan", style="orange"),
|
||||||
Text(final_plan_redaction),
|
Text(final_plan_redaction),
|
||||||
|
@ -724,9 +678,7 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_FACTS_UPDATE,
|
"content": USER_PROMPT_FACTS_UPDATE,
|
||||||
}
|
}
|
||||||
facts_update = self.model(
|
facts_update = self.model([facts_update_system_prompt] + agent_memory + [facts_update_message]).content
|
||||||
[facts_update_system_prompt] + agent_memory + [facts_update_message]
|
|
||||||
).content
|
|
||||||
|
|
||||||
# Redact updated plan
|
# Redact updated plan
|
||||||
plan_update_message = {
|
plan_update_message = {
|
||||||
|
@ -737,12 +689,8 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_PLAN_UPDATE.format(
|
"content": USER_PROMPT_PLAN_UPDATE.format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=get_tool_descriptions(
|
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
|
||||||
self.tools, self.tool_description_template
|
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
|
||||||
),
|
|
||||||
managed_agents_descriptions=(
|
|
||||||
show_agents_descriptions(self.managed_agents)
|
|
||||||
),
|
|
||||||
facts_update=facts_update,
|
facts_update=facts_update,
|
||||||
remaining_steps=(self.max_steps - step),
|
remaining_steps=(self.max_steps - step),
|
||||||
),
|
),
|
||||||
|
@ -753,16 +701,12 @@ Now begin!""",
|
||||||
).content
|
).content
|
||||||
|
|
||||||
# Log final facts and plan
|
# Log final facts and plan
|
||||||
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
|
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
|
||||||
task=task, plan_update=plan_update
|
|
||||||
)
|
|
||||||
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
||||||
```
|
```
|
||||||
{facts_update}
|
{facts_update}
|
||||||
```"""
|
```"""
|
||||||
self.logs.append(
|
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
|
||||||
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
|
|
||||||
)
|
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
Rule("[bold]Updated plan", style="orange"),
|
Rule("[bold]Updated plan", style="orange"),
|
||||||
Text(final_plan_redaction),
|
Text(final_plan_redaction),
|
||||||
|
@ -816,19 +760,13 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
tool_arguments = tool_call.function.arguments
|
tool_arguments = tool_call.function.arguments
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(
|
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}")
|
||||||
f"Error in generating tool call with model:\n{e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
log_entry.tool_calls = [
|
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
|
||||||
ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
Panel(
|
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
|
||||||
Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")
|
|
||||||
),
|
|
||||||
level=LogLevel.INFO,
|
level=LogLevel.INFO,
|
||||||
)
|
)
|
||||||
if tool_name == "final_answer":
|
if tool_name == "final_answer":
|
||||||
|
@ -900,16 +838,10 @@ class CodeAgent(MultiStepAgent):
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = CODE_SYSTEM_PROMPT
|
system_prompt = CODE_SYSTEM_PROMPT
|
||||||
|
|
||||||
self.additional_authorized_imports = (
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
additional_authorized_imports if additional_authorized_imports else []
|
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
||||||
)
|
|
||||||
self.authorized_imports = list(
|
|
||||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
|
||||||
)
|
|
||||||
if "{{authorized_imports}}" not in system_prompt:
|
if "{{authorized_imports}}" not in system_prompt:
|
||||||
raise AgentError(
|
raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.")
|
||||||
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
|
||||||
)
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
tools=tools,
|
tools=tools,
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -966,9 +898,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
log_entry.agent_memory = agent_memory.copy()
|
log_entry.agent_memory = agent_memory.copy()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
additional_args = (
|
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
{"grammar": self.grammar} if self.grammar is not None else {}
|
|
||||||
)
|
|
||||||
llm_output = self.model(
|
llm_output = self.model(
|
||||||
self.input_messages,
|
self.input_messages,
|
||||||
stop_sequences=["<end_code>", "Observation:"],
|
stop_sequences=["<end_code>", "Observation:"],
|
||||||
|
@ -999,9 +929,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
try:
|
try:
|
||||||
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = (
|
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||||
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
|
||||||
)
|
|
||||||
raise AgentParsingError(error_msg)
|
raise AgentParsingError(error_msg)
|
||||||
|
|
||||||
log_entry.tool_calls = [
|
log_entry.tool_calls = [
|
||||||
|
@ -1088,17 +1016,13 @@ class ManagedAgent:
|
||||||
self.description = description
|
self.description = description
|
||||||
self.additional_prompting = additional_prompting
|
self.additional_prompting = additional_prompting
|
||||||
self.provide_run_summary = provide_run_summary
|
self.provide_run_summary = provide_run_summary
|
||||||
self.managed_agent_prompt = (
|
self.managed_agent_prompt = managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
|
||||||
managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
|
|
||||||
)
|
|
||||||
|
|
||||||
def write_full_task(self, task):
|
def write_full_task(self, task):
|
||||||
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
|
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
|
||||||
full_task = self.managed_agent_prompt.format(name=self.name, task=task)
|
full_task = self.managed_agent_prompt.format(name=self.name, task=task)
|
||||||
if self.additional_prompting:
|
if self.additional_prompting:
|
||||||
full_task = full_task.replace(
|
full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip()
|
||||||
"\n{{additional_prompting}}", self.additional_prompting
|
|
||||||
).strip()
|
|
||||||
else:
|
else:
|
||||||
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
|
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
|
||||||
return full_task
|
return full_task
|
||||||
|
@ -1107,9 +1031,7 @@ class ManagedAgent:
|
||||||
full_task = self.write_full_task(request)
|
full_task = self.write_full_task(request)
|
||||||
output = self.agent.run(full_task, **kwargs)
|
output = self.agent.run(full_task, **kwargs)
|
||||||
if self.provide_run_summary:
|
if self.provide_run_summary:
|
||||||
answer = (
|
answer = f"Here is the final answer from your managed agent '{self.name}':\n"
|
||||||
f"Here is the final answer from your managed agent '{self.name}':\n"
|
|
||||||
)
|
|
||||||
answer += str(output)
|
answer += str(output)
|
||||||
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
|
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
|
||||||
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
|
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
|
||||||
|
|
|
@ -20,8 +20,6 @@ from dataclasses import dataclass
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, list_spaces
|
from huggingface_hub import hf_hub_download, list_spaces
|
||||||
|
|
||||||
|
|
||||||
from transformers.utils import is_offline_mode, is_torch_available
|
from transformers.utils import is_offline_mode, is_torch_available
|
||||||
|
|
||||||
from .local_python_executor import (
|
from .local_python_executor import (
|
||||||
|
@ -32,6 +30,7 @@ from .local_python_executor import (
|
||||||
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
||||||
from .types import AgentAudio
|
from .types import AgentAudio
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.models.whisper import (
|
from transformers.models.whisper import (
|
||||||
WhisperForConditionalGeneration,
|
WhisperForConditionalGeneration,
|
||||||
|
@ -61,9 +60,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
|
||||||
tools = {}
|
tools = {}
|
||||||
for space_info in spaces:
|
for space_info in spaces:
|
||||||
repo_id = space_info.id
|
repo_id = space_info.id
|
||||||
resolved_config_file = hf_hub_download(
|
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
|
||||||
repo_id, TOOL_CONFIG_FILE, repo_type="space"
|
|
||||||
)
|
|
||||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||||
config = json.load(reader)
|
config = json.load(reader)
|
||||||
task = repo_id.split("/")[-1]
|
task = repo_id.split("/")[-1]
|
||||||
|
@ -94,9 +91,7 @@ class PythonInterpreterTool(Tool):
|
||||||
if authorized_imports is None:
|
if authorized_imports is None:
|
||||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
|
self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
|
||||||
else:
|
else:
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
|
||||||
set(BASE_BUILTIN_MODULES) | set(authorized_imports)
|
|
||||||
)
|
|
||||||
self.inputs = {
|
self.inputs = {
|
||||||
"code": {
|
"code": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
@ -126,9 +121,7 @@ class PythonInterpreterTool(Tool):
|
||||||
class FinalAnswerTool(Tool):
|
class FinalAnswerTool(Tool):
|
||||||
name = "final_answer"
|
name = "final_answer"
|
||||||
description = "Provides a final answer to the given problem."
|
description = "Provides a final answer to the given problem."
|
||||||
inputs = {
|
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
||||||
"answer": {"type": "any", "description": "The final answer to the problem"}
|
|
||||||
}
|
|
||||||
output_type = "any"
|
output_type = "any"
|
||||||
|
|
||||||
def forward(self, answer):
|
def forward(self, answer):
|
||||||
|
@ -138,9 +131,7 @@ class FinalAnswerTool(Tool):
|
||||||
class UserInputTool(Tool):
|
class UserInputTool(Tool):
|
||||||
name = "user_input"
|
name = "user_input"
|
||||||
description = "Asks for user's input on a specific question"
|
description = "Asks for user's input on a specific question"
|
||||||
inputs = {
|
inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
|
||||||
"question": {"type": "string", "description": "The question to ask the user"}
|
|
||||||
}
|
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def forward(self, question):
|
def forward(self, question):
|
||||||
|
@ -151,9 +142,7 @@ class UserInputTool(Tool):
|
||||||
class DuckDuckGoSearchTool(Tool):
|
class DuckDuckGoSearchTool(Tool):
|
||||||
name = "web_search"
|
name = "web_search"
|
||||||
description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
|
description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
|
||||||
inputs = {
|
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
|
||||||
"query": {"type": "string", "description": "The search query to perform."}
|
|
||||||
}
|
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def __init__(self, *args, max_results=10, **kwargs):
|
def __init__(self, *args, max_results=10, **kwargs):
|
||||||
|
@ -169,10 +158,7 @@ class DuckDuckGoSearchTool(Tool):
|
||||||
|
|
||||||
def forward(self, query: str) -> str:
|
def forward(self, query: str) -> str:
|
||||||
results = self.ddgs.text(query, max_results=self.max_results)
|
results = self.ddgs.text(query, max_results=self.max_results)
|
||||||
postprocessed_results = [
|
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
|
||||||
f"[{result['title']}]({result['href']})\n{result['body']}"
|
|
||||||
for result in results
|
|
||||||
]
|
|
||||||
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
||||||
|
|
||||||
|
|
||||||
|
@ -199,9 +185,7 @@ class GoogleSearchTool(Tool):
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
if self.serpapi_key is None:
|
if self.serpapi_key is None:
|
||||||
raise ValueError(
|
raise ValueError("Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables.")
|
||||||
"Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables."
|
|
||||||
)
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"engine": "google",
|
"engine": "google",
|
||||||
|
@ -210,9 +194,7 @@ class GoogleSearchTool(Tool):
|
||||||
"google_domain": "google.com",
|
"google_domain": "google.com",
|
||||||
}
|
}
|
||||||
if filter_year is not None:
|
if filter_year is not None:
|
||||||
params["tbs"] = (
|
params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
|
||||||
f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
|
|
||||||
)
|
|
||||||
|
|
||||||
response = requests.get("https://serpapi.com/search.json", params=params)
|
response = requests.get("https://serpapi.com/search.json", params=params)
|
||||||
|
|
||||||
|
@ -227,13 +209,9 @@ class GoogleSearchTool(Tool):
|
||||||
f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
|
f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(f"'organic_results' key not found for query: '{query}'. Use a less restrictive query.")
|
||||||
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
|
|
||||||
)
|
|
||||||
if len(results["organic_results"]) == 0:
|
if len(results["organic_results"]) == 0:
|
||||||
year_filter_message = (
|
year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
|
||||||
f" with filter year={filter_year}" if filter_year is not None else ""
|
|
||||||
)
|
|
||||||
return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
|
return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
|
||||||
|
|
||||||
web_snippets = []
|
web_snippets = []
|
||||||
|
@ -253,9 +231,7 @@ class GoogleSearchTool(Tool):
|
||||||
|
|
||||||
redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
|
redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
|
||||||
|
|
||||||
redacted_version = redacted_version.replace(
|
redacted_version = redacted_version.replace("Your browser can't play this video.", "")
|
||||||
"Your browser can't play this video.", ""
|
|
||||||
)
|
|
||||||
web_snippets.append(redacted_version)
|
web_snippets.append(redacted_version)
|
||||||
|
|
||||||
return "## Search Results\n" + "\n\n".join(web_snippets)
|
return "## Search Results\n" + "\n\n".join(web_snippets)
|
||||||
|
@ -263,7 +239,9 @@ class GoogleSearchTool(Tool):
|
||||||
|
|
||||||
class VisitWebpageTool(Tool):
|
class VisitWebpageTool(Tool):
|
||||||
name = "visit_webpage"
|
name = "visit_webpage"
|
||||||
description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
|
description = (
|
||||||
|
"Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
|
||||||
|
)
|
||||||
inputs = {
|
inputs = {
|
||||||
"url": {
|
"url": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
@ -277,6 +255,7 @@ class VisitWebpageTool(Tool):
|
||||||
import requests
|
import requests
|
||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
|
|
||||||
from smolagents.utils import truncate_content
|
from smolagents.utils import truncate_content
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
|
|
@ -28,6 +28,7 @@ from .tool_validation import validate_tool_attributes
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
from .utils import BASE_BUILTIN_MODULES, instance_to_source
|
from .utils import BASE_BUILTIN_MODULES, instance_to_source
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,9 +46,7 @@ class E2BExecutor:
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
additional_imports = additional_imports + ["pickle5", "smolagents"]
|
additional_imports = additional_imports + ["pickle5", "smolagents"]
|
||||||
if len(additional_imports) > 0:
|
if len(additional_imports) > 0:
|
||||||
execution = self.sbx.commands.run(
|
execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
|
||||||
"pip install " + " ".join(additional_imports)
|
|
||||||
)
|
|
||||||
if execution.error:
|
if execution.error:
|
||||||
raise Exception(f"Error installing dependencies: {execution.error}")
|
raise Exception(f"Error installing dependencies: {execution.error}")
|
||||||
else:
|
else:
|
||||||
|
@ -61,9 +60,7 @@ class E2BExecutor:
|
||||||
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
|
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
|
||||||
tool_codes.append(tool_code)
|
tool_codes.append(tool_code)
|
||||||
|
|
||||||
tool_definition_code = "\n".join(
|
tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
|
||||||
[f"import {module}" for module in BASE_BUILTIN_MODULES]
|
|
||||||
)
|
|
||||||
tool_definition_code += textwrap.dedent("""
|
tool_definition_code += textwrap.dedent("""
|
||||||
class Tool:
|
class Tool:
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
@ -122,9 +119,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
|
||||||
for attribute_name in ["jpeg", "png"]:
|
for attribute_name in ["jpeg", "png"]:
|
||||||
if getattr(result, attribute_name) is not None:
|
if getattr(result, attribute_name) is not None:
|
||||||
image_output = getattr(result, attribute_name)
|
image_output = getattr(result, attribute_name)
|
||||||
decoded_bytes = base64.b64decode(
|
decoded_bytes = base64.b64decode(image_output.encode("utf-8"))
|
||||||
image_output.encode("utf-8")
|
|
||||||
)
|
|
||||||
return Image.open(BytesIO(decoded_bytes)), execution_logs
|
return Image.open(BytesIO(decoded_bytes)), execution_logs
|
||||||
for attribute_name in [
|
for attribute_name in [
|
||||||
"chart",
|
"chart",
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import gradio as gr
|
|
||||||
import shutil
|
|
||||||
import os
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from .agents import ActionStep, AgentStepLog, MultiStepAgent
|
from .agents import ActionStep, AgentStepLog, MultiStepAgent
|
||||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||||
|
|
||||||
|
@ -59,9 +59,7 @@ def stream_to_gradio(
|
||||||
):
|
):
|
||||||
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||||
|
|
||||||
for step_log in agent.run(
|
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
|
||||||
task, stream=True, reset=reset_agent_memory, additional_args=additional_args
|
|
||||||
):
|
|
||||||
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
|
@ -147,14 +145,10 @@ class GradioUI:
|
||||||
sanitized_name = "".join(sanitized_name)
|
sanitized_name = "".join(sanitized_name)
|
||||||
|
|
||||||
# Save the uploaded file to the specified folder
|
# Save the uploaded file to the specified folder
|
||||||
file_path = os.path.join(
|
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
|
||||||
self.file_upload_folder, os.path.basename(sanitized_name)
|
|
||||||
)
|
|
||||||
shutil.copy(file.name, file_path)
|
shutil.copy(file.name, file_path)
|
||||||
|
|
||||||
return gr.Textbox(
|
return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
|
||||||
f"File uploaded: {file_path}", visible=True
|
|
||||||
), file_uploads_log + [file_path]
|
|
||||||
|
|
||||||
def log_user_message(self, text_input, file_uploads_log):
|
def log_user_message(self, text_input, file_uploads_log):
|
||||||
return (
|
return (
|
||||||
|
@ -183,9 +177,7 @@ class GradioUI:
|
||||||
# If an upload folder is provided, enable the upload feature
|
# If an upload folder is provided, enable the upload feature
|
||||||
if self.file_upload_folder is not None:
|
if self.file_upload_folder is not None:
|
||||||
upload_file = gr.File(label="Upload a file")
|
upload_file = gr.File(label="Upload a file")
|
||||||
upload_status = gr.Textbox(
|
upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
|
||||||
label="Upload Status", interactive=False, visible=False
|
|
||||||
)
|
|
||||||
upload_file.change(
|
upload_file.change(
|
||||||
self.upload_file,
|
self.upload_file,
|
||||||
[upload_file, file_uploads_log],
|
[upload_file, file_uploads_log],
|
||||||
|
|
|
@ -42,8 +42,7 @@ class InterpreterError(ValueError):
|
||||||
ERRORS = {
|
ERRORS = {
|
||||||
name: getattr(builtins, name)
|
name: getattr(builtins, name)
|
||||||
for name in dir(builtins)
|
for name in dir(builtins)
|
||||||
if isinstance(getattr(builtins, name), type)
|
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
|
||||||
and issubclass(getattr(builtins, name), BaseException)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
|
PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
|
||||||
|
@ -167,9 +166,7 @@ def evaluate_unaryop(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
operand = evaluate_ast(
|
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.operand, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if isinstance(expression.op, ast.USub):
|
if isinstance(expression.op, ast.USub):
|
||||||
return -operand
|
return -operand
|
||||||
elif isinstance(expression.op, ast.UAdd):
|
elif isinstance(expression.op, ast.UAdd):
|
||||||
|
@ -179,9 +176,7 @@ def evaluate_unaryop(
|
||||||
elif isinstance(expression.op, ast.Invert):
|
elif isinstance(expression.op, ast.Invert):
|
||||||
return ~operand
|
return ~operand
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
||||||
f"Unary operation {expression.op.__class__.__name__} is not supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_lambda(
|
def evaluate_lambda(
|
||||||
|
@ -217,23 +212,17 @@ def evaluate_while(
|
||||||
) -> None:
|
) -> None:
|
||||||
max_iterations = 1000
|
max_iterations = 1000
|
||||||
iterations = 0
|
iterations = 0
|
||||||
while evaluate_ast(
|
while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
|
||||||
while_loop.test, state, static_tools, custom_tools, authorized_imports
|
|
||||||
):
|
|
||||||
for node in while_loop.body:
|
for node in while_loop.body:
|
||||||
try:
|
try:
|
||||||
evaluate_ast(
|
evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||||
node, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
except BreakException:
|
except BreakException:
|
||||||
return None
|
return None
|
||||||
except ContinueException:
|
except ContinueException:
|
||||||
break
|
break
|
||||||
iterations += 1
|
iterations += 1
|
||||||
if iterations > max_iterations:
|
if iterations > max_iterations:
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
|
||||||
f"Maximum number of {max_iterations} iterations in While loop exceeded"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -248,8 +237,7 @@ def create_function(
|
||||||
func_state = state.copy()
|
func_state = state.copy()
|
||||||
arg_names = [arg.arg for arg in func_def.args.args]
|
arg_names = [arg.arg for arg in func_def.args.args]
|
||||||
default_values = [
|
default_values = [
|
||||||
evaluate_ast(d, state, static_tools, custom_tools, authorized_imports)
|
evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in func_def.args.defaults
|
||||||
for d in func_def.args.defaults
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Apply default values
|
# Apply default values
|
||||||
|
@ -286,9 +274,7 @@ def create_function(
|
||||||
result = None
|
result = None
|
||||||
try:
|
try:
|
||||||
for stmt in func_def.body:
|
for stmt in func_def.body:
|
||||||
result = evaluate_ast(
|
result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
|
||||||
stmt, func_state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
except ReturnException as e:
|
except ReturnException as e:
|
||||||
result = e.value
|
result = e.value
|
||||||
|
|
||||||
|
@ -307,9 +293,7 @@ def evaluate_function_def(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
custom_tools[func_def.name] = create_function(
|
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
|
||||||
func_def, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
return custom_tools[func_def.name]
|
return custom_tools[func_def.name]
|
||||||
|
|
||||||
|
|
||||||
|
@ -321,17 +305,12 @@ def evaluate_class_def(
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> type:
|
) -> type:
|
||||||
class_name = class_def.name
|
class_name = class_def.name
|
||||||
bases = [
|
bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
|
||||||
evaluate_ast(base, state, static_tools, custom_tools, authorized_imports)
|
|
||||||
for base in class_def.bases
|
|
||||||
]
|
|
||||||
class_dict = {}
|
class_dict = {}
|
||||||
|
|
||||||
for stmt in class_def.body:
|
for stmt in class_def.body:
|
||||||
if isinstance(stmt, ast.FunctionDef):
|
if isinstance(stmt, ast.FunctionDef):
|
||||||
class_dict[stmt.name] = evaluate_function_def(
|
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools, authorized_imports)
|
||||||
stmt, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(stmt, ast.Assign):
|
elif isinstance(stmt, ast.Assign):
|
||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
|
@ -351,9 +330,7 @@ def evaluate_class_def(
|
||||||
authorized_imports,
|
authorized_imports,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
||||||
f"Unsupported statement in class body: {stmt.__class__.__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
new_class = type(class_name, tuple(bases), class_dict)
|
new_class = type(class_name, tuple(bases), class_dict)
|
||||||
state[class_name] = new_class
|
state[class_name] = new_class
|
||||||
|
@ -371,38 +348,26 @@ def evaluate_augassign(
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
return state.get(target.id, 0)
|
return state.get(target.id, 0)
|
||||||
elif isinstance(target, ast.Subscript):
|
elif isinstance(target, ast.Subscript):
|
||||||
obj = evaluate_ast(
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
target.value, state, static_tools, custom_tools, authorized_imports
|
key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
|
||||||
)
|
|
||||||
key = evaluate_ast(
|
|
||||||
target.slice, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
return obj[key]
|
return obj[key]
|
||||||
elif isinstance(target, ast.Attribute):
|
elif isinstance(target, ast.Attribute):
|
||||||
obj = evaluate_ast(
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
target.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
return getattr(obj, target.attr)
|
return getattr(obj, target.attr)
|
||||||
elif isinstance(target, ast.Tuple):
|
elif isinstance(target, ast.Tuple):
|
||||||
return tuple(get_current_value(elt) for elt in target.elts)
|
return tuple(get_current_value(elt) for elt in target.elts)
|
||||||
elif isinstance(target, ast.List):
|
elif isinstance(target, ast.List):
|
||||||
return [get_current_value(elt) for elt in target.elts]
|
return [get_current_value(elt) for elt in target.elts]
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
||||||
"AugAssign not supported for {type(target)} targets."
|
|
||||||
)
|
|
||||||
|
|
||||||
current_value = get_current_value(expression.target)
|
current_value = get_current_value(expression.target)
|
||||||
value_to_add = evaluate_ast(
|
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(expression.op, ast.Add):
|
if isinstance(expression.op, ast.Add):
|
||||||
if isinstance(current_value, list):
|
if isinstance(current_value, list):
|
||||||
if not isinstance(value_to_add, list):
|
if not isinstance(value_to_add, list):
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
|
||||||
f"Cannot add non-list value {value_to_add} to a list."
|
|
||||||
)
|
|
||||||
updated_value = current_value + value_to_add
|
updated_value = current_value + value_to_add
|
||||||
else:
|
else:
|
||||||
updated_value = current_value + value_to_add
|
updated_value = current_value + value_to_add
|
||||||
|
@ -429,9 +394,7 @@ def evaluate_augassign(
|
||||||
elif isinstance(expression.op, ast.RShift):
|
elif isinstance(expression.op, ast.RShift):
|
||||||
updated_value = current_value >> value_to_add
|
updated_value = current_value >> value_to_add
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
||||||
f"Operation {type(expression.op).__name__} is not supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update the state
|
# Update the state
|
||||||
set_value(
|
set_value(
|
||||||
|
@ -455,16 +418,12 @@ def evaluate_boolop(
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if isinstance(node.op, ast.And):
|
if isinstance(node.op, ast.And):
|
||||||
for value in node.values:
|
for value in node.values:
|
||||||
if not evaluate_ast(
|
if not evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
|
||||||
value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
):
|
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
elif isinstance(node.op, ast.Or):
|
elif isinstance(node.op, ast.Or):
|
||||||
for value in node.values:
|
for value in node.values:
|
||||||
if evaluate_ast(
|
if evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
|
||||||
value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -477,12 +436,8 @@ def evaluate_binop(
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
# Recursively evaluate the left and right operands
|
# Recursively evaluate the left and right operands
|
||||||
left_val = evaluate_ast(
|
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
|
||||||
binop.left, state, static_tools, custom_tools, authorized_imports
|
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)
|
||||||
)
|
|
||||||
right_val = evaluate_ast(
|
|
||||||
binop.right, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine the operation based on the type of the operator in the BinOp
|
# Determine the operation based on the type of the operator in the BinOp
|
||||||
if isinstance(binop.op, ast.Add):
|
if isinstance(binop.op, ast.Add):
|
||||||
|
@ -510,9 +465,7 @@ def evaluate_binop(
|
||||||
elif isinstance(binop.op, ast.RShift):
|
elif isinstance(binop.op, ast.RShift):
|
||||||
return left_val >> right_val
|
return left_val >> right_val
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
||||||
f"Binary operation {type(binop.op).__name__} is not implemented."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_assign(
|
def evaluate_assign(
|
||||||
|
@ -522,17 +475,13 @@ def evaluate_assign(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
result = evaluate_ast(
|
result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
assign.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if len(assign.targets) == 1:
|
if len(assign.targets) == 1:
|
||||||
target = assign.targets[0]
|
target = assign.targets[0]
|
||||||
set_value(target, result, state, static_tools, custom_tools, authorized_imports)
|
set_value(target, result, state, static_tools, custom_tools, authorized_imports)
|
||||||
else:
|
else:
|
||||||
if len(assign.targets) != len(result):
|
if len(assign.targets) != len(result):
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
||||||
f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
|
|
||||||
)
|
|
||||||
expanded_values = []
|
expanded_values = []
|
||||||
for tgt in assign.targets:
|
for tgt in assign.targets:
|
||||||
if isinstance(tgt, ast.Starred):
|
if isinstance(tgt, ast.Starred):
|
||||||
|
@ -554,9 +503,7 @@ def set_value(
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in static_tools:
|
if target.id in static_tools:
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
||||||
f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
|
|
||||||
)
|
|
||||||
state[target.id] = value
|
state[target.id] = value
|
||||||
elif isinstance(target, ast.Tuple):
|
elif isinstance(target, ast.Tuple):
|
||||||
if not isinstance(value, tuple):
|
if not isinstance(value, tuple):
|
||||||
|
@ -567,21 +514,13 @@ def set_value(
|
||||||
if len(target.elts) != len(value):
|
if len(target.elts) != len(value):
|
||||||
raise InterpreterError("Cannot unpack tuple of wrong size")
|
raise InterpreterError("Cannot unpack tuple of wrong size")
|
||||||
for i, elem in enumerate(target.elts):
|
for i, elem in enumerate(target.elts):
|
||||||
set_value(
|
set_value(elem, value[i], state, static_tools, custom_tools, authorized_imports)
|
||||||
elem, value[i], state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(target, ast.Subscript):
|
elif isinstance(target, ast.Subscript):
|
||||||
obj = evaluate_ast(
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
target.value, state, static_tools, custom_tools, authorized_imports
|
key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
|
||||||
)
|
|
||||||
key = evaluate_ast(
|
|
||||||
target.slice, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
obj[key] = value
|
obj[key] = value
|
||||||
elif isinstance(target, ast.Attribute):
|
elif isinstance(target, ast.Attribute):
|
||||||
obj = evaluate_ast(
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
target.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
setattr(obj, target.attr, value)
|
setattr(obj, target.attr, value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -593,15 +532,11 @@ def evaluate_call(
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not (
|
if not (
|
||||||
isinstance(call.func, ast.Attribute)
|
isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript)
|
||||||
or isinstance(call.func, ast.Name)
|
|
||||||
or isinstance(call.func, ast.Subscript)
|
|
||||||
):
|
):
|
||||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||||
if isinstance(call.func, ast.Attribute):
|
if isinstance(call.func, ast.Attribute):
|
||||||
obj = evaluate_ast(
|
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
call.func.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
func_name = call.func.attr
|
func_name = call.func.attr
|
||||||
if not hasattr(obj, func_name):
|
if not hasattr(obj, func_name):
|
||||||
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
||||||
|
@ -623,18 +558,12 @@ def evaluate_call(
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(call.func, ast.Subscript):
|
elif isinstance(call.func, ast.Subscript):
|
||||||
value = evaluate_ast(
|
value = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
call.func.value, state, static_tools, custom_tools, authorized_imports
|
index = evaluate_ast(call.func.slice, state, static_tools, custom_tools, authorized_imports)
|
||||||
)
|
|
||||||
index = evaluate_ast(
|
|
||||||
call.func.slice, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if isinstance(value, (list, tuple)):
|
if isinstance(value, (list, tuple)):
|
||||||
func = value[index]
|
func = value[index]
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Cannot subscript object of type {type(value).__name__}")
|
||||||
f"Cannot subscript object of type {type(value).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not callable(func):
|
if not callable(func):
|
||||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||||
|
@ -642,20 +571,12 @@ def evaluate_call(
|
||||||
args = []
|
args = []
|
||||||
for arg in call.args:
|
for arg in call.args:
|
||||||
if isinstance(arg, ast.Starred):
|
if isinstance(arg, ast.Starred):
|
||||||
args.extend(
|
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
|
||||||
evaluate_ast(
|
|
||||||
arg.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
args.append(
|
args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))
|
||||||
evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports)
|
|
||||||
)
|
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
keyword.arg: evaluate_ast(
|
keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
keyword.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
for keyword in call.keywords
|
for keyword in call.keywords
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -693,17 +614,11 @@ def evaluate_subscript(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
index = evaluate_ast(
|
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)
|
||||||
subscript.slice, state, static_tools, custom_tools, authorized_imports
|
value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
)
|
|
||||||
value = evaluate_ast(
|
|
||||||
subscript.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(value, str) and isinstance(index, str):
|
if isinstance(value, str) and isinstance(index, str):
|
||||||
raise InterpreterError(
|
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
|
||||||
"You're trying to subscript a string with a string index, which is impossible"
|
|
||||||
)
|
|
||||||
if isinstance(value, pd.core.indexing._LocIndexer):
|
if isinstance(value, pd.core.indexing._LocIndexer):
|
||||||
parent_object = value.obj
|
parent_object = value.obj
|
||||||
return parent_object.loc[index]
|
return parent_object.loc[index]
|
||||||
|
@ -718,15 +633,11 @@ def evaluate_subscript(
|
||||||
return value[index]
|
return value[index]
|
||||||
elif isinstance(value, (list, tuple)):
|
elif isinstance(value, (list, tuple)):
|
||||||
if not (-len(value) <= index < len(value)):
|
if not (-len(value) <= index < len(value)):
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
||||||
f"Index {index} out of bounds for list of length {len(value)}"
|
|
||||||
)
|
|
||||||
return value[int(index)]
|
return value[int(index)]
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
if not (-len(value) <= index < len(value)):
|
if not (-len(value) <= index < len(value)):
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
||||||
f"Index {index} out of bounds for string of length {len(value)}"
|
|
||||||
)
|
|
||||||
return value[index]
|
return value[index]
|
||||||
elif index in value:
|
elif index in value:
|
||||||
return value[index]
|
return value[index]
|
||||||
|
@ -765,12 +676,9 @@ def evaluate_condition(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
left = evaluate_ast(
|
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
|
||||||
condition.left, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
comparators = [
|
comparators = [
|
||||||
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports)
|
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
|
||||||
for c in condition.comparators
|
|
||||||
]
|
]
|
||||||
ops = [type(op) for op in condition.ops]
|
ops = [type(op) for op in condition.ops]
|
||||||
|
|
||||||
|
@ -818,21 +726,15 @@ def evaluate_if(
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
result = None
|
result = None
|
||||||
test_result = evaluate_ast(
|
test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
|
||||||
if_statement.test, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if test_result:
|
if test_result:
|
||||||
for line in if_statement.body:
|
for line in if_statement.body:
|
||||||
line_result = evaluate_ast(
|
line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
|
||||||
line, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if line_result is not None:
|
if line_result is not None:
|
||||||
result = line_result
|
result = line_result
|
||||||
else:
|
else:
|
||||||
for line in if_statement.orelse:
|
for line in if_statement.orelse:
|
||||||
line_result = evaluate_ast(
|
line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
|
||||||
line, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if line_result is not None:
|
if line_result is not None:
|
||||||
result = line_result
|
result = line_result
|
||||||
return result
|
return result
|
||||||
|
@ -846,9 +748,7 @@ def evaluate_for(
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
result = None
|
result = None
|
||||||
iterator = evaluate_ast(
|
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
|
||||||
for_loop.iter, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
for counter in iterator:
|
for counter in iterator:
|
||||||
set_value(
|
set_value(
|
||||||
for_loop.target,
|
for_loop.target,
|
||||||
|
@ -860,9 +760,7 @@ def evaluate_for(
|
||||||
)
|
)
|
||||||
for node in for_loop.body:
|
for node in for_loop.body:
|
||||||
try:
|
try:
|
||||||
line_result = evaluate_ast(
|
line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||||
node, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if line_result is not None:
|
if line_result is not None:
|
||||||
result = line_result
|
result = line_result
|
||||||
except BreakException:
|
except BreakException:
|
||||||
|
@ -882,9 +780,7 @@ def evaluate_listcomp(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
def inner_evaluate(
|
def inner_evaluate(generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]) -> List[Any]:
|
||||||
generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]
|
|
||||||
) -> List[Any]:
|
|
||||||
if index >= len(generators):
|
if index >= len(generators):
|
||||||
return [
|
return [
|
||||||
evaluate_ast(
|
evaluate_ast(
|
||||||
|
@ -912,9 +808,7 @@ def evaluate_listcomp(
|
||||||
else:
|
else:
|
||||||
new_state[generator.target.id] = value
|
new_state[generator.target.id] = value
|
||||||
if all(
|
if all(
|
||||||
evaluate_ast(
|
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
|
||||||
if_clause, new_state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
for if_clause in generator.ifs
|
for if_clause in generator.ifs
|
||||||
):
|
):
|
||||||
result.extend(inner_evaluate(generators, index + 1, new_state))
|
result.extend(inner_evaluate(generators, index + 1, new_state))
|
||||||
|
@ -938,32 +832,24 @@ def evaluate_try(
|
||||||
for handler in try_node.handlers:
|
for handler in try_node.handlers:
|
||||||
if handler.type is None or isinstance(
|
if handler.type is None or isinstance(
|
||||||
e,
|
e,
|
||||||
evaluate_ast(
|
evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
|
||||||
handler.type, state, static_tools, custom_tools, authorized_imports
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
matched = True
|
matched = True
|
||||||
if handler.name:
|
if handler.name:
|
||||||
state[handler.name] = e
|
state[handler.name] = e
|
||||||
for stmt in handler.body:
|
for stmt in handler.body:
|
||||||
evaluate_ast(
|
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
|
||||||
stmt, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
if not matched:
|
if not matched:
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
if try_node.orelse:
|
if try_node.orelse:
|
||||||
for stmt in try_node.orelse:
|
for stmt in try_node.orelse:
|
||||||
evaluate_ast(
|
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
|
||||||
stmt, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
if try_node.finalbody:
|
if try_node.finalbody:
|
||||||
for stmt in try_node.finalbody:
|
for stmt in try_node.finalbody:
|
||||||
evaluate_ast(
|
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
|
||||||
stmt, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_raise(
|
def evaluate_raise(
|
||||||
|
@ -974,15 +860,11 @@ def evaluate_raise(
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
if raise_node.exc is not None:
|
if raise_node.exc is not None:
|
||||||
exc = evaluate_ast(
|
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports)
|
||||||
raise_node.exc, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
exc = None
|
exc = None
|
||||||
if raise_node.cause is not None:
|
if raise_node.cause is not None:
|
||||||
cause = evaluate_ast(
|
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports)
|
||||||
raise_node.cause, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
cause = None
|
cause = None
|
||||||
if exc is not None:
|
if exc is not None:
|
||||||
|
@ -1001,14 +883,10 @@ def evaluate_assert(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
test_result = evaluate_ast(
|
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports)
|
||||||
assert_node.test, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if not test_result:
|
if not test_result:
|
||||||
if assert_node.msg:
|
if assert_node.msg:
|
||||||
msg = evaluate_ast(
|
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports)
|
||||||
assert_node.msg, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg)
|
||||||
else:
|
else:
|
||||||
# Include the failing condition in the assertion message
|
# Include the failing condition in the assertion message
|
||||||
|
@ -1025,9 +903,7 @@ def evaluate_with(
|
||||||
) -> None:
|
) -> None:
|
||||||
contexts = []
|
contexts = []
|
||||||
for item in with_node.items:
|
for item in with_node.items:
|
||||||
context_expr = evaluate_ast(
|
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
|
||||||
item.context_expr, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if item.optional_vars:
|
if item.optional_vars:
|
||||||
state[item.optional_vars.id] = context_expr.__enter__()
|
state[item.optional_vars.id] = context_expr.__enter__()
|
||||||
contexts.append(state[item.optional_vars.id])
|
contexts.append(state[item.optional_vars.id])
|
||||||
|
@ -1069,19 +945,14 @@ def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
|
||||||
# Copy all attributes by reference, recursively checking modules
|
# Copy all attributes by reference, recursively checking modules
|
||||||
for attr_name in dir(unsafe_module):
|
for attr_name in dir(unsafe_module):
|
||||||
# Skip dangerous patterns at any level
|
# Skip dangerous patterns at any level
|
||||||
if any(
|
if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns):
|
||||||
pattern in f"{unsafe_module.__name__}.{attr_name}"
|
|
||||||
for pattern in dangerous_patterns
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
attr_value = getattr(unsafe_module, attr_name)
|
attr_value = getattr(unsafe_module, attr_name)
|
||||||
|
|
||||||
# Recursively process nested modules, passing visited set
|
# Recursively process nested modules, passing visited set
|
||||||
if isinstance(attr_value, ModuleType):
|
if isinstance(attr_value, ModuleType):
|
||||||
attr_value = get_safe_module(
|
attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited)
|
||||||
attr_value, dangerous_patterns, visited=visited
|
|
||||||
)
|
|
||||||
|
|
||||||
setattr(safe_module, attr_name, attr_value)
|
setattr(safe_module, attr_name, attr_value)
|
||||||
|
|
||||||
|
@ -1116,18 +987,14 @@ def import_modules(expression, state, authorized_imports):
|
||||||
module_path = module_name.split(".")
|
module_path = module_name.split(".")
|
||||||
if any([module in dangerous_patterns for module in module_path]):
|
if any([module in dangerous_patterns for module in module_path]):
|
||||||
return False
|
return False
|
||||||
module_subpaths = [
|
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||||
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
|
|
||||||
]
|
|
||||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||||
|
|
||||||
if isinstance(expression, ast.Import):
|
if isinstance(expression, ast.Import):
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
if check_module_authorized(alias.name):
|
if check_module_authorized(alias.name):
|
||||||
raw_module = import_module(alias.name)
|
raw_module = import_module(alias.name)
|
||||||
state[alias.asname or alias.name] = get_safe_module(
|
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns)
|
||||||
raw_module, dangerous_patterns
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(
|
||||||
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
||||||
|
@ -1135,9 +1002,7 @@ def import_modules(expression, state, authorized_imports):
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
if check_module_authorized(expression.module):
|
if check_module_authorized(expression.module):
|
||||||
raw_module = __import__(
|
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||||
expression.module, fromlist=[alias.name for alias in expression.names]
|
|
||||||
)
|
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
state[alias.asname or alias.name] = get_safe_module(
|
state[alias.asname or alias.name] = get_safe_module(
|
||||||
getattr(raw_module, alias.name), dangerous_patterns
|
getattr(raw_module, alias.name), dangerous_patterns
|
||||||
|
@ -1156,9 +1021,7 @@ def evaluate_dictcomp(
|
||||||
) -> Dict[Any, Any]:
|
) -> Dict[Any, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
for gen in dictcomp.generators:
|
for gen in dictcomp.generators:
|
||||||
iter_value = evaluate_ast(
|
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
|
||||||
gen.iter, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
for value in iter_value:
|
for value in iter_value:
|
||||||
new_state = state.copy()
|
new_state = state.copy()
|
||||||
set_value(
|
set_value(
|
||||||
|
@ -1170,9 +1033,7 @@ def evaluate_dictcomp(
|
||||||
authorized_imports,
|
authorized_imports,
|
||||||
)
|
)
|
||||||
if all(
|
if all(
|
||||||
evaluate_ast(
|
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
|
||||||
if_clause, new_state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
for if_clause in gen.ifs
|
for if_clause in gen.ifs
|
||||||
):
|
):
|
||||||
key = evaluate_ast(
|
key = evaluate_ast(
|
||||||
|
@ -1229,202 +1090,116 @@ def evaluate_ast(
|
||||||
if isinstance(expression, ast.Assign):
|
if isinstance(expression, ast.Assign):
|
||||||
# Assignment -> we evaluate the assignment which should update the state
|
# Assignment -> we evaluate the assignment which should update the state
|
||||||
# We return the variable assigned as it may be used to determine the final result.
|
# We return the variable assigned as it may be used to determine the final result.
|
||||||
return evaluate_assign(
|
return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.AugAssign):
|
elif isinstance(expression, ast.AugAssign):
|
||||||
return evaluate_augassign(
|
return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Call):
|
elif isinstance(expression, ast.Call):
|
||||||
# Function call -> we return the value of the function call
|
# Function call -> we return the value of the function call
|
||||||
return evaluate_call(
|
return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Constant):
|
elif isinstance(expression, ast.Constant):
|
||||||
# Constant -> just return the value
|
# Constant -> just return the value
|
||||||
return expression.value
|
return expression.value
|
||||||
elif isinstance(expression, ast.Tuple):
|
elif isinstance(expression, ast.Tuple):
|
||||||
return tuple(
|
return tuple(
|
||||||
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
|
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts
|
||||||
for elt in expression.elts
|
|
||||||
)
|
)
|
||||||
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
||||||
return evaluate_listcomp(
|
return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.UnaryOp):
|
elif isinstance(expression, ast.UnaryOp):
|
||||||
return evaluate_unaryop(
|
return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Starred):
|
elif isinstance(expression, ast.Starred):
|
||||||
return evaluate_ast(
|
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.BoolOp):
|
elif isinstance(expression, ast.BoolOp):
|
||||||
# Boolean operation -> evaluate the operation
|
# Boolean operation -> evaluate the operation
|
||||||
return evaluate_boolop(
|
return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Break):
|
elif isinstance(expression, ast.Break):
|
||||||
raise BreakException()
|
raise BreakException()
|
||||||
elif isinstance(expression, ast.Continue):
|
elif isinstance(expression, ast.Continue):
|
||||||
raise ContinueException()
|
raise ContinueException()
|
||||||
elif isinstance(expression, ast.BinOp):
|
elif isinstance(expression, ast.BinOp):
|
||||||
# Binary operation -> execute operation
|
# Binary operation -> execute operation
|
||||||
return evaluate_binop(
|
return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Compare):
|
elif isinstance(expression, ast.Compare):
|
||||||
# Comparison -> evaluate the comparison
|
# Comparison -> evaluate the comparison
|
||||||
return evaluate_condition(
|
return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Lambda):
|
elif isinstance(expression, ast.Lambda):
|
||||||
return evaluate_lambda(
|
return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.FunctionDef):
|
elif isinstance(expression, ast.FunctionDef):
|
||||||
return evaluate_function_def(
|
return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Dict):
|
elif isinstance(expression, ast.Dict):
|
||||||
# Dict -> evaluate all keys and values
|
# Dict -> evaluate all keys and values
|
||||||
keys = [
|
keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys]
|
||||||
evaluate_ast(k, state, static_tools, custom_tools, authorized_imports)
|
values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values]
|
||||||
for k in expression.keys
|
|
||||||
]
|
|
||||||
values = [
|
|
||||||
evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)
|
|
||||||
for v in expression.values
|
|
||||||
]
|
|
||||||
return dict(zip(keys, values))
|
return dict(zip(keys, values))
|
||||||
elif isinstance(expression, ast.Expr):
|
elif isinstance(expression, ast.Expr):
|
||||||
# Expression -> evaluate the content
|
# Expression -> evaluate the content
|
||||||
return evaluate_ast(
|
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.For):
|
elif isinstance(expression, ast.For):
|
||||||
# For loop -> execute the loop
|
# For loop -> execute the loop
|
||||||
return evaluate_for(
|
return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.FormattedValue):
|
elif isinstance(expression, ast.FormattedValue):
|
||||||
# Formatted value (part of f-string) -> evaluate the content and return
|
# Formatted value (part of f-string) -> evaluate the content and return
|
||||||
return evaluate_ast(
|
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.If):
|
elif isinstance(expression, ast.If):
|
||||||
# If -> execute the right branch
|
# If -> execute the right branch
|
||||||
return evaluate_if(
|
return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||||
return evaluate_ast(
|
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.JoinedStr):
|
elif isinstance(expression, ast.JoinedStr):
|
||||||
return "".join(
|
return "".join(
|
||||||
[
|
[str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values]
|
||||||
str(
|
|
||||||
evaluate_ast(
|
|
||||||
v, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for v in expression.values
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
elif isinstance(expression, ast.List):
|
elif isinstance(expression, ast.List):
|
||||||
# List -> evaluate all elements
|
# List -> evaluate all elements
|
||||||
return [
|
return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts]
|
||||||
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
|
|
||||||
for elt in expression.elts
|
|
||||||
]
|
|
||||||
elif isinstance(expression, ast.Name):
|
elif isinstance(expression, ast.Name):
|
||||||
# Name -> pick up the value in the state
|
# Name -> pick up the value in the state
|
||||||
return evaluate_name(
|
return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Subscript):
|
elif isinstance(expression, ast.Subscript):
|
||||||
# Subscript -> return the value of the indexing
|
# Subscript -> return the value of the indexing
|
||||||
return evaluate_subscript(
|
return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.IfExp):
|
elif isinstance(expression, ast.IfExp):
|
||||||
test_val = evaluate_ast(
|
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.test, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if test_val:
|
if test_val:
|
||||||
return evaluate_ast(
|
return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.body, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return evaluate_ast(
|
return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.orelse, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Attribute):
|
elif isinstance(expression, ast.Attribute):
|
||||||
value = evaluate_ast(
|
value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
return getattr(value, expression.attr)
|
return getattr(value, expression.attr)
|
||||||
elif isinstance(expression, ast.Slice):
|
elif isinstance(expression, ast.Slice):
|
||||||
return slice(
|
return slice(
|
||||||
evaluate_ast(
|
evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.lower, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if expression.lower is not None
|
if expression.lower is not None
|
||||||
else None,
|
else None,
|
||||||
evaluate_ast(
|
evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.upper, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if expression.upper is not None
|
if expression.upper is not None
|
||||||
else None,
|
else None,
|
||||||
evaluate_ast(
|
evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.step, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if expression.step is not None
|
if expression.step is not None
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
elif isinstance(expression, ast.DictComp):
|
elif isinstance(expression, ast.DictComp):
|
||||||
return evaluate_dictcomp(
|
return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.While):
|
elif isinstance(expression, ast.While):
|
||||||
return evaluate_while(
|
return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
|
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
|
||||||
return import_modules(expression, state, authorized_imports)
|
return import_modules(expression, state, authorized_imports)
|
||||||
elif isinstance(expression, ast.ClassDef):
|
elif isinstance(expression, ast.ClassDef):
|
||||||
return evaluate_class_def(
|
return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Try):
|
elif isinstance(expression, ast.Try):
|
||||||
return evaluate_try(
|
return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Raise):
|
elif isinstance(expression, ast.Raise):
|
||||||
return evaluate_raise(
|
return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Assert):
|
elif isinstance(expression, ast.Assert):
|
||||||
return evaluate_assert(
|
return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.With):
|
elif isinstance(expression, ast.With):
|
||||||
return evaluate_with(
|
return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
elif isinstance(expression, ast.Set):
|
elif isinstance(expression, ast.Set):
|
||||||
return {
|
return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts}
|
||||||
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
|
|
||||||
for elt in expression.elts
|
|
||||||
}
|
|
||||||
elif isinstance(expression, ast.Return):
|
elif isinstance(expression, ast.Return):
|
||||||
raise ReturnException(
|
raise ReturnException(
|
||||||
evaluate_ast(
|
evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
|
||||||
)
|
|
||||||
if expression.value
|
if expression.value
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
@ -1488,18 +1263,12 @@ def evaluate_python_code(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for node in expression.body:
|
for node in expression.body:
|
||||||
result = evaluate_ast(
|
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||||
node, state, static_tools, custom_tools, authorized_imports
|
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
|
||||||
)
|
|
||||||
state["print_outputs"] = truncate_content(
|
|
||||||
PRINT_OUTPUTS, max_length=max_print_outputs_length
|
|
||||||
)
|
|
||||||
is_final_answer = False
|
is_final_answer = False
|
||||||
return result, is_final_answer
|
return result, is_final_answer
|
||||||
except FinalAnswerException as e:
|
except FinalAnswerException as e:
|
||||||
state["print_outputs"] = truncate_content(
|
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
|
||||||
PRINT_OUTPUTS, max_length=max_print_outputs_length
|
|
||||||
)
|
|
||||||
is_final_answer = True
|
is_final_answer = True
|
||||||
return e.value, is_final_answer
|
return e.value, is_final_answer
|
||||||
except InterpreterError as e:
|
except InterpreterError as e:
|
||||||
|
@ -1521,9 +1290,7 @@ class LocalPythonInterpreter:
|
||||||
if max_print_outputs_length is None:
|
if max_print_outputs_length is None:
|
||||||
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
|
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
|
||||||
self.additional_authorized_imports = additional_authorized_imports
|
self.additional_authorized_imports = additional_authorized_imports
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
||||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
|
||||||
)
|
|
||||||
# Add base trusted tools to list
|
# Add base trusted tools to list
|
||||||
self.static_tools = {
|
self.static_tools = {
|
||||||
**tools,
|
**tools,
|
||||||
|
@ -1531,9 +1298,7 @@ class LocalPythonInterpreter:
|
||||||
}
|
}
|
||||||
# TODO: assert self.authorized imports are all installed locally
|
# TODO: assert self.authorized imports are all installed locally
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]:
|
||||||
self, code_action: str, additional_variables: Dict
|
|
||||||
) -> Tuple[Any, str, bool]:
|
|
||||||
self.state.update(additional_variables)
|
self.state.update(additional_variables)
|
||||||
output, is_final_answer = evaluate_python_code(
|
output, is_final_answer = evaluate_python_code(
|
||||||
code_action,
|
code_action,
|
||||||
|
|
|
@ -14,17 +14,16 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union, Any
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import InferenceClient
|
from huggingface_hub import InferenceClient
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
@ -35,6 +34,7 @@ from transformers import (
|
||||||
|
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
|
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
|
||||||
|
@ -100,10 +100,7 @@ class ChatMessage:
|
||||||
def from_hf_api(cls, message) -> "ChatMessage":
|
def from_hf_api(cls, message) -> "ChatMessage":
|
||||||
tool_calls = None
|
tool_calls = None
|
||||||
if getattr(message, "tool_calls", None) is not None:
|
if getattr(message, "tool_calls", None) is not None:
|
||||||
tool_calls = [
|
tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
|
||||||
ChatMessageToolCall.from_hf_api(tool_call)
|
|
||||||
for tool_call in message.tool_calls
|
|
||||||
]
|
|
||||||
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
@ -172,17 +169,12 @@ def get_clean_message_list(
|
||||||
|
|
||||||
role = message["role"]
|
role = message["role"]
|
||||||
if role not in MessageRole.roles():
|
if role not in MessageRole.roles():
|
||||||
raise ValueError(
|
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
|
||||||
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
|
|
||||||
)
|
|
||||||
|
|
||||||
if role in role_conversions:
|
if role in role_conversions:
|
||||||
message["role"] = role_conversions[role]
|
message["role"] = role_conversions[role]
|
||||||
|
|
||||||
if (
|
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
|
||||||
len(final_message_list) > 0
|
|
||||||
and message["role"] == final_message_list[-1]["role"]
|
|
||||||
):
|
|
||||||
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
|
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
|
||||||
else:
|
else:
|
||||||
final_message_list.append(message)
|
final_message_list.append(message)
|
||||||
|
@ -292,9 +284,7 @@ class HfApiModel(Model):
|
||||||
Gets an LLM output message for the given list of input messages.
|
Gets an LLM output message for the given list of input messages.
|
||||||
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
||||||
"""
|
"""
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
|
||||||
messages, role_conversions=tool_role_conversions
|
|
||||||
)
|
|
||||||
if tools_to_call_from:
|
if tools_to_call_from:
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -367,9 +357,7 @@ class TransformersModel(Model):
|
||||||
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
|
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
|
||||||
if model_id is None:
|
if model_id is None:
|
||||||
model_id = default_model_id
|
model_id = default_model_id
|
||||||
logger.warning(
|
logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'")
|
||||||
f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
|
|
||||||
)
|
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
|
@ -389,9 +377,7 @@ class TransformersModel(Model):
|
||||||
)
|
)
|
||||||
self.model_id = default_model_id
|
self.model_id = default_model_id
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
|
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
|
||||||
model_id, device_map=device_map, torch_dtype=torch_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
|
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
|
||||||
class StopOnStrings(StoppingCriteria):
|
class StopOnStrings(StoppingCriteria):
|
||||||
|
@ -404,16 +390,9 @@ class TransformersModel(Model):
|
||||||
self.stream = ""
|
self.stream = ""
|
||||||
|
|
||||||
def __call__(self, input_ids, scores, **kwargs):
|
def __call__(self, input_ids, scores, **kwargs):
|
||||||
generated = self.tokenizer.decode(
|
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
|
||||||
input_ids[0][-1], skip_special_tokens=True
|
|
||||||
)
|
|
||||||
self.stream += generated
|
self.stream += generated
|
||||||
if any(
|
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
|
||||||
[
|
|
||||||
self.stream.endswith(stop_string)
|
|
||||||
for stop_string in self.stop_strings
|
|
||||||
]
|
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -426,9 +405,7 @@ class TransformersModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> ChatMessage:
|
) -> ChatMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
|
||||||
messages, role_conversions=tool_role_conversions
|
|
||||||
)
|
|
||||||
if tools_to_call_from is not None:
|
if tools_to_call_from is not None:
|
||||||
prompt_tensor = self.tokenizer.apply_chat_template(
|
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
|
@ -448,9 +425,7 @@ class TransformersModel(Model):
|
||||||
|
|
||||||
out = self.model.generate(
|
out = self.model.generate(
|
||||||
**prompt_tensor,
|
**prompt_tensor,
|
||||||
stopping_criteria=(
|
stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
|
||||||
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
|
|
||||||
),
|
|
||||||
**self.kwargs,
|
**self.kwargs,
|
||||||
)
|
)
|
||||||
generated_tokens = out[0, count_prompt_tokens:]
|
generated_tokens = out[0, count_prompt_tokens:]
|
||||||
|
@ -475,9 +450,7 @@ class TransformersModel(Model):
|
||||||
ChatMessageToolCall(
|
ChatMessageToolCall(
|
||||||
id="".join(random.choices("0123456789", k=5)),
|
id="".join(random.choices("0123456789", k=5)),
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatMessageToolCallDefinition(
|
function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
|
||||||
name=tool_name, arguments=tool_arguments
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -525,9 +498,7 @@ class LiteLLMModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> ChatMessage:
|
) -> ChatMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
|
||||||
messages, role_conversions=tool_role_conversions
|
|
||||||
)
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
if tools_to_call_from:
|
if tools_to_call_from:
|
||||||
|
@ -604,11 +575,7 @@ class OpenAIServerModel(Model):
|
||||||
) -> ChatMessage:
|
) -> ChatMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages,
|
messages,
|
||||||
role_conversions=(
|
role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions),
|
||||||
self.custom_role_conversions
|
|
||||||
if self.custom_role_conversions
|
|
||||||
else tool_role_conversions
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if tools_to_call_from:
|
if tools_to_call_from:
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
|
|
|
@ -22,10 +22,7 @@ class Monitor:
|
||||||
self.step_durations = []
|
self.step_durations = []
|
||||||
self.tracked_model = tracked_model
|
self.tracked_model = tracked_model
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
if (
|
if getattr(self.tracked_model, "last_input_token_count", "Not found") != "Not found":
|
||||||
getattr(self.tracked_model, "last_input_token_count", "Not found")
|
|
||||||
!= "Not found"
|
|
||||||
):
|
|
||||||
self.total_input_token_count = 0
|
self.total_input_token_count = 0
|
||||||
self.total_output_token_count = 0
|
self.total_output_token_count = 0
|
||||||
|
|
||||||
|
@ -48,7 +45,9 @@ class Monitor:
|
||||||
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
|
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
|
||||||
self.total_input_token_count += self.tracked_model.last_input_token_count
|
self.total_input_token_count += self.tracked_model.last_input_token_count
|
||||||
self.total_output_token_count += self.tracked_model.last_output_token_count
|
self.total_output_token_count += self.tracked_model.last_output_token_count
|
||||||
console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
|
console_outputs += (
|
||||||
|
f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
|
||||||
|
)
|
||||||
console_outputs += "]"
|
console_outputs += "]"
|
||||||
self.logger.log(Text(console_outputs, style="dim"), level=1)
|
self.logger.log(Text(console_outputs, style="dim"), level=1)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from typing import Set
|
||||||
|
|
||||||
from .utils import BASE_BUILTIN_MODULES
|
from .utils import BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
|
|
||||||
_BUILTIN_NAMES = set(vars(builtins))
|
_BUILTIN_NAMES = set(vars(builtins))
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,9 +142,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
||||||
# Check that __init__ method takes no arguments
|
# Check that __init__ method takes no arguments
|
||||||
if not cls.__init__.__qualname__ == "Tool.__init__":
|
if not cls.__init__.__qualname__ == "Tool.__init__":
|
||||||
sig = inspect.signature(cls.__init__)
|
sig = inspect.signature(cls.__init__)
|
||||||
non_self_params = list(
|
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
|
||||||
[arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]
|
|
||||||
)
|
|
||||||
if len(non_self_params) > 0:
|
if len(non_self_params) > 0:
|
||||||
errors.append(
|
errors.append(
|
||||||
f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
|
f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
|
||||||
|
@ -174,9 +173,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
||||||
|
|
||||||
# Check if the assignment is more complex than simple literals
|
# Check if the assignment is more complex than simple literals
|
||||||
if not all(
|
if not all(
|
||||||
isinstance(
|
isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
|
||||||
val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)
|
|
||||||
)
|
|
||||||
for val in ast.walk(node.value)
|
for val in ast.walk(node.value)
|
||||||
):
|
):
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
|
@ -195,9 +192,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
||||||
# Run checks on all methods
|
# Run checks on all methods
|
||||||
for node in class_node.body:
|
for node in class_node.body:
|
||||||
if isinstance(node, ast.FunctionDef):
|
if isinstance(node, ast.FunctionDef):
|
||||||
method_checker = MethodChecker(
|
method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
|
||||||
class_level_checker.class_attributes, check_imports=check_imports
|
|
||||||
)
|
|
||||||
method_checker.visit(node)
|
method_checker.visit(node)
|
||||||
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
|
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,6 @@ from huggingface_hub import (
|
||||||
upload_folder,
|
upload_folder,
|
||||||
)
|
)
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.dynamic_module_utils import get_imports
|
from transformers.dynamic_module_utils import get_imports
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
|
@ -52,6 +51,7 @@ from .tool_validation import MethodChecker, validate_tool_attributes
|
||||||
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
|
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
|
||||||
from .utils import instance_to_source
|
from .utils import instance_to_source
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
|
@ -77,9 +77,7 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
||||||
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
|
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
|
||||||
return "model"
|
return "model"
|
||||||
except RepositoryNotFoundError:
|
except RepositoryNotFoundError:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
|
||||||
f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return "model"
|
return "model"
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -109,9 +107,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||||
properties[param_name]["nullable"] = True
|
properties[param_name]["nullable"] = True
|
||||||
for param_name in signature.parameters.keys():
|
for param_name in signature.parameters.keys():
|
||||||
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
||||||
if (
|
if param_name not in properties: # this can happen if the param has no type hint but a default value
|
||||||
param_name not in properties
|
|
||||||
): # this can happen if the param has no type hint but a default value
|
|
||||||
properties[param_name] = {"nullable": True}
|
properties[param_name] = {"nullable": True}
|
||||||
return properties
|
return properties
|
||||||
|
|
||||||
|
@ -181,9 +177,7 @@ class Tool:
|
||||||
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
||||||
)
|
)
|
||||||
for input_name, input_content in self.inputs.items():
|
for input_name, input_content in self.inputs.items():
|
||||||
assert isinstance(input_content, dict), (
|
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
|
||||||
f"Input '{input_name}' should be a dictionary."
|
|
||||||
)
|
|
||||||
assert "type" in input_content and "description" in input_content, (
|
assert "type" in input_content and "description" in input_content, (
|
||||||
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||||
)
|
)
|
||||||
|
@ -348,15 +342,7 @@ class Tool:
|
||||||
imports = []
|
imports = []
|
||||||
for module in [tool_file]:
|
for module in [tool_file]:
|
||||||
imports.extend(get_imports(module))
|
imports.extend(get_imports(module))
|
||||||
imports = list(
|
imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
|
||||||
set(
|
|
||||||
[
|
|
||||||
el
|
|
||||||
for el in imports + ["smolagents"]
|
|
||||||
if el not in sys.stdlib_module_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
with open(requirements_file, "w", encoding="utf-8") as f:
|
with open(requirements_file, "w", encoding="utf-8") as f:
|
||||||
f.write("\n".join(imports) + "\n")
|
f.write("\n".join(imports) + "\n")
|
||||||
|
|
||||||
|
@ -410,9 +396,7 @@ class Tool:
|
||||||
print(work_dir)
|
print(work_dir)
|
||||||
with open(work_dir + "/tool.py", "r") as f:
|
with open(work_dir + "/tool.py", "r") as f:
|
||||||
print("\n".join(f.readlines()))
|
print("\n".join(f.readlines()))
|
||||||
logger.info(
|
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
|
||||||
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
|
|
||||||
)
|
|
||||||
return upload_folder(
|
return upload_folder(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
commit_message=commit_message,
|
commit_message=commit_message,
|
||||||
|
@ -592,9 +576,7 @@ class Tool:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
self.client = Client(space_id, hf_token=token)
|
self.client = Client(space_id, hf_token=token)
|
||||||
space_description = self.client.view_api(
|
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
|
||||||
return_format="dict", print_info=False
|
|
||||||
)["named_endpoints"]
|
|
||||||
|
|
||||||
# If api_name is not defined, take the first of the available APIs for this space
|
# If api_name is not defined, take the first of the available APIs for this space
|
||||||
if api_name is None:
|
if api_name is None:
|
||||||
|
@ -607,9 +589,7 @@ class Tool:
|
||||||
try:
|
try:
|
||||||
space_description_api = space_description[api_name]
|
space_description_api = space_description[api_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise KeyError(
|
raise KeyError(f"Could not find specified {api_name=} among available api names.")
|
||||||
f"Could not find specified {api_name=} among available api names."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.inputs = {}
|
self.inputs = {}
|
||||||
for parameter in space_description_api["parameters"]:
|
for parameter in space_description_api["parameters"]:
|
||||||
|
@ -683,8 +663,7 @@ class Tool:
|
||||||
self._gradio_tool = _gradio_tool
|
self._gradio_tool = _gradio_tool
|
||||||
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
|
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
|
||||||
self.inputs = {
|
self.inputs = {
|
||||||
key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
|
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
|
||||||
for key, value in func_args
|
|
||||||
}
|
}
|
||||||
self.forward = self._gradio_tool.run
|
self.forward = self._gradio_tool.run
|
||||||
|
|
||||||
|
@ -726,9 +705,7 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_tool_description_with_args(
|
def get_tool_description_with_args(tool: Tool, description_template: Optional[str] = None) -> str:
|
||||||
tool: Tool, description_template: Optional[str] = None
|
|
||||||
) -> str:
|
|
||||||
if description_template is None:
|
if description_template is None:
|
||||||
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||||
compiled_template = compile_jinja_template(description_template)
|
compiled_template = compile_jinja_template(description_template)
|
||||||
|
@ -748,10 +725,7 @@ def compile_jinja_template(template):
|
||||||
raise ImportError("template requires jinja2 to be installed.")
|
raise ImportError("template requires jinja2 to be installed.")
|
||||||
|
|
||||||
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
||||||
raise ImportError(
|
raise ImportError(f"template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}.")
|
||||||
"template requires jinja2>=3.1.0 to be installed. Your version is "
|
|
||||||
f"{jinja2.__version__}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def raise_exception(message):
|
def raise_exception(message):
|
||||||
raise TemplateError(message)
|
raise TemplateError(message)
|
||||||
|
@ -772,9 +746,7 @@ def launch_gradio_demo(tool: Tool):
|
||||||
try:
|
try:
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
||||||
"Gradio should be installed in order to launch a gradio demo."
|
|
||||||
)
|
|
||||||
|
|
||||||
TYPE_TO_COMPONENT_CLASS_MAPPING = {
|
TYPE_TO_COMPONENT_CLASS_MAPPING = {
|
||||||
"image": gr.Image,
|
"image": gr.Image,
|
||||||
|
@ -791,9 +763,7 @@ def launch_gradio_demo(tool: Tool):
|
||||||
|
|
||||||
gradio_inputs = []
|
gradio_inputs = []
|
||||||
for input_name, input_details in tool.inputs.items():
|
for input_name, input_details in tool.inputs.items():
|
||||||
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
|
||||||
input_details["type"]
|
|
||||||
]
|
|
||||||
new_component = input_gradio_component_class(label=input_name)
|
new_component = input_gradio_component_class(label=input_name)
|
||||||
gradio_inputs.append(new_component)
|
gradio_inputs.append(new_component)
|
||||||
|
|
||||||
|
@ -922,14 +892,9 @@ class ToolCollection:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
_collection = get_collection(collection_slug, token=token)
|
_collection = get_collection(collection_slug, token=token)
|
||||||
_hub_repo_ids = {
|
_hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"}
|
||||||
item.item_id for item in _collection.items if item.item_type == "space"
|
|
||||||
}
|
|
||||||
|
|
||||||
tools = {
|
tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids}
|
||||||
Tool.from_hub(repo_id, token, trust_remote_code)
|
|
||||||
for repo_id in _hub_repo_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
return cls(tools)
|
return cls(tools)
|
||||||
|
|
||||||
|
@ -986,9 +951,7 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
"""
|
"""
|
||||||
parameters = get_json_schema(tool_function)["function"]
|
parameters = get_json_schema(tool_function)["function"]
|
||||||
if "return" not in parameters:
|
if "return" not in parameters:
|
||||||
raise TypeHintParsingException(
|
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
|
||||||
"Tool return type not found: make sure your function has a return type hint!"
|
|
||||||
)
|
|
||||||
|
|
||||||
class SimpleTool(Tool):
|
class SimpleTool(Tool):
|
||||||
def __init__(self, name, description, inputs, output_type, function):
|
def __init__(self, name, description, inputs, output_type, function):
|
||||||
|
@ -1007,9 +970,9 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
function=tool_function,
|
function=tool_function,
|
||||||
)
|
)
|
||||||
original_signature = inspect.signature(tool_function)
|
original_signature = inspect.signature(tool_function)
|
||||||
new_parameters = [
|
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)] + list(
|
||||||
inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)
|
original_signature.parameters.values()
|
||||||
] + list(original_signature.parameters.values())
|
)
|
||||||
new_signature = original_signature.replace(parameters=new_parameters)
|
new_signature = original_signature.replace(parameters=new_parameters)
|
||||||
simple_tool.forward.__signature__ = new_signature
|
simple_tool.forward.__signature__ = new_signature
|
||||||
return simple_tool
|
return simple_tool
|
||||||
|
@ -1082,9 +1045,7 @@ class PipelineTool(Tool):
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
if self.default_checkpoint is None:
|
if self.default_checkpoint is None:
|
||||||
raise ValueError(
|
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
|
||||||
"This tool does not implement a default checkpoint, you need to pass one."
|
|
||||||
)
|
|
||||||
model = self.default_checkpoint
|
model = self.default_checkpoint
|
||||||
if pre_processor is None:
|
if pre_processor is None:
|
||||||
pre_processor = model
|
pre_processor = model
|
||||||
|
@ -1107,21 +1068,15 @@ class PipelineTool(Tool):
|
||||||
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
||||||
"""
|
"""
|
||||||
if isinstance(self.pre_processor, str):
|
if isinstance(self.pre_processor, str):
|
||||||
self.pre_processor = self.pre_processor_class.from_pretrained(
|
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
|
||||||
self.pre_processor, **self.hub_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(self.model, str):
|
if isinstance(self.model, str):
|
||||||
self.model = self.model_class.from_pretrained(
|
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
|
||||||
self.model, **self.model_kwargs, **self.hub_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.post_processor is None:
|
if self.post_processor is None:
|
||||||
self.post_processor = self.pre_processor
|
self.post_processor = self.pre_processor
|
||||||
elif isinstance(self.post_processor, str):
|
elif isinstance(self.post_processor, str):
|
||||||
self.post_processor = self.post_processor_class.from_pretrained(
|
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
|
||||||
self.post_processor, **self.hub_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
if self.device_map is not None:
|
if self.device_map is not None:
|
||||||
|
@ -1165,12 +1120,8 @@ class PipelineTool(Tool):
|
||||||
|
|
||||||
encoded_inputs = self.encode(*args, **kwargs)
|
encoded_inputs = self.encode(*args, **kwargs)
|
||||||
|
|
||||||
tensor_inputs = {
|
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
|
||||||
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
|
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
|
||||||
}
|
|
||||||
non_tensor_inputs = {
|
|
||||||
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
|
|
||||||
}
|
|
||||||
|
|
||||||
encoded_inputs = send_to_device(tensor_inputs, self.device)
|
encoded_inputs = send_to_device(tensor_inputs, self.device)
|
||||||
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
|
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
|
||||||
|
|
|
@ -27,6 +27,7 @@ from transformers.utils import (
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
|
@ -113,9 +114,7 @@ class AgentImage(AgentType, ImageType):
|
||||||
elif isinstance(value, np.ndarray):
|
elif isinstance(value, np.ndarray):
|
||||||
self._tensor = torch.from_numpy(value)
|
self._tensor = torch.from_numpy(value)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
|
||||||
f"Unsupported type for {self.__class__.__name__}: {type(value)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _ipython_display_(self, include=None, exclude=None):
|
def _ipython_display_(self, include=None, exclude=None):
|
||||||
"""
|
"""
|
||||||
|
@ -264,9 +263,7 @@ if is_torch_available():
|
||||||
|
|
||||||
def handle_agent_input_types(*args, **kwargs):
|
def handle_agent_input_types(*args, **kwargs):
|
||||||
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
||||||
kwargs = {
|
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
|
||||||
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
|
|
||||||
}
|
|
||||||
return args, kwargs
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@ -279,9 +276,7 @@ def handle_agent_output_types(output, output_type=None):
|
||||||
# If the class does not have defined output, then we map according to the type
|
# If the class does not have defined output, then we map according to the type
|
||||||
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||||
if isinstance(output, _k):
|
if isinstance(output, _k):
|
||||||
if (
|
if _k is not object: # avoid converting to audio if torch is not installed
|
||||||
_k is not object
|
|
||||||
): # avoid converting to audio if torch is not installed
|
|
||||||
return _v(output)
|
return _v(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
@ -83,9 +83,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||||
try:
|
try:
|
||||||
first_accolade_index = json_blob.find("{")
|
first_accolade_index = json_blob.find("{")
|
||||||
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
|
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
|
||||||
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
|
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
|
||||||
'\\"', "'"
|
|
||||||
)
|
|
||||||
json_data = json.loads(json_blob, strict=False)
|
json_data = json.loads(json_blob, strict=False)
|
||||||
return json_data
|
return json_data
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
@ -162,9 +160,7 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
|
||||||
MAX_LENGTH_TRUNCATE_CONTENT = 20000
|
MAX_LENGTH_TRUNCATE_CONTENT = 20000
|
||||||
|
|
||||||
|
|
||||||
def truncate_content(
|
def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
|
||||||
content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
|
|
||||||
) -> str:
|
|
||||||
if len(content) <= max_length:
|
if len(content) <= max_length:
|
||||||
return content
|
return content
|
||||||
else:
|
else:
|
||||||
|
@ -206,12 +202,8 @@ def is_same_method(method1, method2):
|
||||||
source2 = get_method_source(method2)
|
source2 = get_method_source(method2)
|
||||||
|
|
||||||
# Remove method decorators if any
|
# Remove method decorators if any
|
||||||
source1 = "\n".join(
|
source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
|
||||||
line for line in source1.split("\n") if not line.strip().startswith("@")
|
source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))
|
||||||
)
|
|
||||||
source2 = "\n".join(
|
|
||||||
line for line in source2.split("\n") if not line.strip().startswith("@")
|
|
||||||
)
|
|
||||||
|
|
||||||
return source1 == source2
|
return source1 == source2
|
||||||
except (TypeError, OSError):
|
except (TypeError, OSError):
|
||||||
|
@ -248,9 +240,7 @@ def instance_to_source(instance, base_cls=None):
|
||||||
for name, value in cls.__dict__.items()
|
for name, value in cls.__dict__.items()
|
||||||
if not name.startswith("__")
|
if not name.startswith("__")
|
||||||
and not callable(value)
|
and not callable(value)
|
||||||
and not (
|
and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
|
||||||
base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, value in class_attrs.items():
|
for name, value in class_attrs.items():
|
||||||
|
@ -271,9 +261,7 @@ def instance_to_source(instance, base_cls=None):
|
||||||
for name, func in cls.__dict__.items()
|
for name, func in cls.__dict__.items()
|
||||||
if callable(func)
|
if callable(func)
|
||||||
and not (
|
and not (
|
||||||
base_cls
|
base_cls and hasattr(base_cls, name) and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
|
||||||
and hasattr(base_cls, name)
|
|
||||||
and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,9 +272,7 @@ def instance_to_source(instance, base_cls=None):
|
||||||
first_line = method_lines[0]
|
first_line = method_lines[0]
|
||||||
indent = len(first_line) - len(first_line.lstrip())
|
indent = len(first_line) - len(first_line.lstrip())
|
||||||
method_lines = [line[indent:] for line in method_lines]
|
method_lines = [line[indent:] for line in method_lines]
|
||||||
method_source = "\n".join(
|
method_source = "\n".join([" " + line if line.strip() else line for line in method_lines])
|
||||||
[" " + line if line.strip() else line for line in method_lines]
|
|
||||||
)
|
|
||||||
class_lines.append(method_source)
|
class_lines.append(method_source)
|
||||||
class_lines.append("")
|
class_lines.append("")
|
||||||
|
|
||||||
|
|
|
@ -28,13 +28,13 @@ from smolagents.agents import (
|
||||||
ToolCallingAgent,
|
ToolCallingAgent,
|
||||||
)
|
)
|
||||||
from smolagents.default_tools import PythonInterpreterTool
|
from smolagents.default_tools import PythonInterpreterTool
|
||||||
from smolagents.tools import tool
|
|
||||||
from smolagents.types import AgentImage, AgentText
|
|
||||||
from smolagents.models import (
|
from smolagents.models import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatMessageToolCall,
|
ChatMessageToolCall,
|
||||||
ChatMessageToolCallDefinition,
|
ChatMessageToolCallDefinition,
|
||||||
)
|
)
|
||||||
|
from smolagents.tools import tool
|
||||||
|
from smolagents.types import AgentImage, AgentText
|
||||||
from smolagents.utils import BASE_BUILTIN_MODULES
|
from smolagents.utils import BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,9 +44,7 @@ def get_new_path(suffix="") -> str:
|
||||||
|
|
||||||
|
|
||||||
class FakeToolCallModel:
|
class FakeToolCallModel:
|
||||||
def __call__(
|
def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
|
||||||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
|
||||||
):
|
|
||||||
if len(messages) < 3:
|
if len(messages) < 3:
|
||||||
return ChatMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
@ -69,18 +67,14 @@ class FakeToolCallModel:
|
||||||
ChatMessageToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_1",
|
id="call_1",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatMessageToolCallDefinition(
|
function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "7.2904"}),
|
||||||
name="final_answer", arguments={"answer": "7.2904"}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FakeToolCallModelImage:
|
class FakeToolCallModelImage:
|
||||||
def __call__(
|
def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
|
||||||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
|
||||||
):
|
|
||||||
if len(messages) < 3:
|
if len(messages) < 3:
|
||||||
return ChatMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
@ -104,9 +98,7 @@ class FakeToolCallModelImage:
|
||||||
ChatMessageToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_1",
|
id="call_1",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatMessageToolCallDefinition(
|
function=ChatMessageToolCallDefinition(name="final_answer", arguments="image.png"),
|
||||||
name="final_answer", arguments="image.png"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -271,17 +263,13 @@ print(result)
|
||||||
|
|
||||||
class AgentTests(unittest.TestCase):
|
class AgentTests(unittest.TestCase):
|
||||||
def test_fake_single_step_code_agent(self):
|
def test_fake_single_step_code_agent(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_single_step)
|
||||||
tools=[PythonInterpreterTool()], model=fake_code_model_single_step
|
|
||||||
)
|
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
|
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert "7.2904" in output
|
assert "7.2904" in output
|
||||||
|
|
||||||
def test_fake_toolcalling_agent(self):
|
def test_fake_toolcalling_agent(self):
|
||||||
agent = ToolCallingAgent(
|
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel())
|
||||||
tools=[PythonInterpreterTool()], model=FakeToolCallModel()
|
|
||||||
)
|
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert "7.2904" in output
|
assert "7.2904" in output
|
||||||
|
@ -301,9 +289,7 @@ class AgentTests(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
|
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
|
||||||
|
|
||||||
agent = ToolCallingAgent(
|
agent = ToolCallingAgent(tools=[fake_image_generation_tool], model=FakeToolCallModelImage())
|
||||||
tools=[fake_image_generation_tool], model=FakeToolCallModelImage()
|
|
||||||
)
|
|
||||||
output = agent.run("Make me an image.")
|
output = agent.run("Make me an image.")
|
||||||
assert isinstance(output, AgentImage)
|
assert isinstance(output, AgentImage)
|
||||||
assert isinstance(agent.state["image.png"], Image.Image)
|
assert isinstance(agent.state["image.png"], Image.Image)
|
||||||
|
@ -315,9 +301,7 @@ class AgentTests(unittest.TestCase):
|
||||||
assert output == 7.2904
|
assert output == 7.2904
|
||||||
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
|
||||||
assert agent.logs[3].tool_calls == [
|
assert agent.logs[3].tool_calls == [
|
||||||
ToolCall(
|
ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_3")
|
||||||
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_additional_args_added_to_task(self):
|
def test_additional_args_added_to_task(self):
|
||||||
|
@ -351,9 +335,7 @@ class AgentTests(unittest.TestCase):
|
||||||
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
|
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
|
||||||
|
|
||||||
def test_code_agent_syntax_error_show_offending_lines(self):
|
def test_code_agent_syntax_error_show_offending_lines(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
|
||||||
tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
|
|
||||||
)
|
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, AgentText)
|
assert isinstance(output, AgentText)
|
||||||
assert output == "got an error"
|
assert output == "got an error"
|
||||||
|
@ -391,9 +373,7 @@ class AgentTests(unittest.TestCase):
|
||||||
def test_init_agent_with_different_toolsets(self):
|
def test_init_agent_with_different_toolsets(self):
|
||||||
toolset_1 = []
|
toolset_1 = []
|
||||||
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
|
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
|
||||||
assert (
|
assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
|
||||||
len(agent.tools) == 1
|
|
||||||
) # when no tools are provided, only the final_answer tool is added by default
|
|
||||||
|
|
||||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||||
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
|
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
|
||||||
|
@ -436,9 +416,7 @@ class AgentTests(unittest.TestCase):
|
||||||
assert "You can also give requests to team members." not in agent.system_prompt
|
assert "You can also give requests to team members." not in agent.system_prompt
|
||||||
print("ok1")
|
print("ok1")
|
||||||
assert "{{managed_agents_descriptions}}" not in agent.system_prompt
|
assert "{{managed_agents_descriptions}}" not in agent.system_prompt
|
||||||
assert (
|
assert "You can also give requests to team members." in manager_agent.system_prompt
|
||||||
"You can also give requests to team members." in manager_agent.system_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
|
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
|
||||||
agent = CodeAgent(tools=[], model=fake_code_model_import)
|
agent = CodeAgent(tools=[], model=fake_code_model_import)
|
||||||
|
|
|
@ -136,9 +136,7 @@ class TestDocs:
|
||||||
try:
|
try:
|
||||||
code_blocks = [
|
code_blocks = [
|
||||||
(
|
(
|
||||||
block.replace(
|
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN"))
|
||||||
"<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
|
|
||||||
)
|
|
||||||
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
|
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
|
||||||
.replace("{your_username}", "m-ric")
|
.replace("{your_username}", "m-ric")
|
||||||
)
|
)
|
||||||
|
@ -150,9 +148,7 @@ class TestDocs:
|
||||||
except SubprocessCallException as e:
|
except SubprocessCallException as e:
|
||||||
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
|
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
|
||||||
except Exception:
|
except Exception:
|
||||||
pytest.fail(
|
pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
|
||||||
f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _setup(self):
|
def _setup(self):
|
||||||
|
@ -174,6 +170,4 @@ def pytest_generate_tests(metafunc):
|
||||||
test_class.setup_class()
|
test_class.setup_class()
|
||||||
|
|
||||||
# Parameterize with the markdown files
|
# Parameterize with the markdown files
|
||||||
metafunc.parametrize(
|
metafunc.parametrize("doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files])
|
||||||
"doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]
|
|
||||||
)
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
|
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
|
||||||
|
@ -23,14 +24,10 @@ from .test_tools import ToolTesterMixin
|
||||||
|
|
||||||
class DefaultToolTests(unittest.TestCase):
|
class DefaultToolTests(unittest.TestCase):
|
||||||
def test_visit_webpage(self):
|
def test_visit_webpage(self):
|
||||||
arguments = {
|
arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"}
|
||||||
"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
|
|
||||||
}
|
|
||||||
result = VisitWebpageTool()(arguments)
|
result = VisitWebpageTool()(arguments)
|
||||||
assert isinstance(result, str)
|
assert isinstance(result, str)
|
||||||
assert (
|
assert "* [About Wikipedia](/wiki/Wikipedia:About)" in result # Proper wikipedia pages have an About
|
||||||
"* [About Wikipedia](/wiki/Wikipedia:About)" in result
|
|
||||||
) # Proper wikipedia pages have an About
|
|
||||||
|
|
||||||
|
|
||||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
|
@ -59,12 +56,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||||
input_type = expected_input["type"]
|
input_type = expected_input["type"]
|
||||||
if isinstance(input_type, list):
|
if isinstance(input_type, list):
|
||||||
_inputs.append(
|
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||||
[
|
|
||||||
AGENT_TYPE_MAPPING[_input_type](_input)
|
|
||||||
for _input_type in input_type
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ from smolagents.types import AGENT_TYPE_MAPPING
|
||||||
|
|
||||||
from .test_tools import ToolTesterMixin
|
from .test_tools import ToolTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -45,11 +46,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
|
|
||||||
def create_inputs(self):
|
def create_inputs(self):
|
||||||
inputs_text = {"answer": "Text input"}
|
inputs_text = {"answer": "Text input"}
|
||||||
inputs_image = {
|
inputs_image = {"answer": Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))}
|
||||||
"answer": Image.open(
|
|
||||||
Path(get_tests_dir("fixtures")) / "000000039769.png"
|
|
||||||
).resize((512, 512))
|
|
||||||
}
|
|
||||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||||
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||||
|
|
||||||
|
|
|
@ -12,11 +12,11 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
|
||||||
import json
|
import json
|
||||||
|
import unittest
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
|
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
||||||
|
|
||||||
|
|
||||||
class ModelTests(unittest.TestCase):
|
class ModelTests(unittest.TestCase):
|
||||||
|
@ -33,12 +33,7 @@ class ModelTests(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
assert (
|
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
|
||||||
"nullable"
|
|
||||||
in models.get_json_schema(get_weather)["function"]["parameters"][
|
|
||||||
"properties"
|
|
||||||
]["celsius"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_chatmessage_has_model_dumps_json(self):
|
def test_chatmessage_has_model_dumps_json(self):
|
||||||
message = ChatMessage("user", "Hello!")
|
message = ChatMessage("user", "Hello!")
|
||||||
|
|
|
@ -43,9 +43,7 @@ class FakeLLMModel:
|
||||||
ChatMessageToolCall(
|
ChatMessageToolCall(
|
||||||
id="fake_id",
|
id="fake_id",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatMessageToolCallDefinition(
|
function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "image"}),
|
||||||
name="final_answer", arguments={"answer": "image"}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -122,9 +120,7 @@ class MonitoringTester(unittest.TestCase):
|
||||||
)
|
)
|
||||||
agent.run("Fake task")
|
agent.run("Fake task")
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(agent.monitor.total_input_token_count, 20) # Should have done two monitoring callbacks
|
||||||
agent.monitor.total_input_token_count, 20
|
|
||||||
) # Should have done two monitoring callbacks
|
|
||||||
self.assertEqual(agent.monitor.total_output_token_count, 0)
|
self.assertEqual(agent.monitor.total_output_token_count, 0)
|
||||||
|
|
||||||
def test_streaming_agent_text_output(self):
|
def test_streaming_agent_text_output(self):
|
||||||
|
|
|
@ -55,10 +55,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
code = "print = '3'"
|
code = "print = '3'"
|
||||||
with pytest.raises(InterpreterError) as e:
|
with pytest.raises(InterpreterError) as e:
|
||||||
evaluate_python_code(code, {"print": print}, state={})
|
evaluate_python_code(code, {"print": print}, state={})
|
||||||
assert (
|
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
|
||||||
"Cannot assign to name 'print': doing this would erase the existing tool!"
|
|
||||||
in str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_subscript_call(self):
|
def test_subscript_call(self):
|
||||||
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
|
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
|
||||||
|
@ -92,9 +89,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||||
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_evaluate_expression(self):
|
def test_evaluate_expression(self):
|
||||||
code = "x = 3\ny = 5"
|
code = "x = 3\ny = 5"
|
||||||
|
@ -110,9 +105,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
result, _ = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result == "This is x: 3."
|
assert result == "This is x: 3."
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
|
||||||
state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_evaluate_if(self):
|
def test_evaluate_if(self):
|
||||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||||
|
@ -153,15 +146,11 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||||
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
|
||||||
)
|
|
||||||
|
|
||||||
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
||||||
state = {}
|
state = {}
|
||||||
evaluate_python_code(
|
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
||||||
code, {"min": min, "print": print, "round": round}, state=state
|
|
||||||
)
|
|
||||||
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
||||||
|
|
||||||
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
||||||
|
@ -317,9 +306,7 @@ print(check_digits)
|
||||||
assert result == {0: 0, 1: 1, 2: 4}
|
assert result == {0: 0, 1: 1, 2: 4}
|
||||||
|
|
||||||
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
||||||
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
|
||||||
)
|
|
||||||
assert result == {102: "b"}
|
assert result == {102: "b"}
|
||||||
|
|
||||||
code = """
|
code = """
|
||||||
|
@ -367,9 +354,7 @@ else:
|
||||||
best_city = "Manhattan"
|
best_city = "Manhattan"
|
||||||
best_city
|
best_city
|
||||||
"""
|
"""
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||||
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
|
||||||
)
|
|
||||||
assert result == "Brooklyn"
|
assert result == "Brooklyn"
|
||||||
|
|
||||||
code = """if d > e and a < b:
|
code = """if d > e and a < b:
|
||||||
|
@ -380,9 +365,7 @@ else:
|
||||||
best_city = "Manhattan"
|
best_city = "Manhattan"
|
||||||
best_city
|
best_city
|
||||||
"""
|
"""
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||||
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
|
||||||
)
|
|
||||||
assert result == "Sacramento"
|
assert result == "Sacramento"
|
||||||
|
|
||||||
def test_if_conditions(self):
|
def test_if_conditions(self):
|
||||||
|
@ -398,9 +381,7 @@ if char.isalpha():
|
||||||
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == 2.0
|
assert result == 2.0
|
||||||
|
|
||||||
code = (
|
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
||||||
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
|
||||||
)
|
|
||||||
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "lose"
|
assert result == "lose"
|
||||||
|
|
||||||
|
@ -434,14 +415,10 @@ if char.isalpha():
|
||||||
|
|
||||||
# Test submodules are handled properly, thus not raising error
|
# Test submodules are handled properly, thus not raising error
|
||||||
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||||
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
|
||||||
)
|
|
||||||
|
|
||||||
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||||
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_additional_imports(self):
|
def test_additional_imports(self):
|
||||||
code = "import numpy as np"
|
code = "import numpy as np"
|
||||||
|
@ -613,9 +590,7 @@ except ValueError as e:
|
||||||
def test_types_as_objects(self):
|
def test_types_as_objects(self):
|
||||||
code = "type_a = float(2); type_b = str; type_c = int"
|
code = "type_a = float(2); type_b = str; type_c = int"
|
||||||
state = {}
|
state = {}
|
||||||
result, is_final_answer = evaluate_python_code(
|
result, is_final_answer = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
||||||
code, {"float": float, "str": str, "int": int}, state=state
|
|
||||||
)
|
|
||||||
assert result is int
|
assert result is int
|
||||||
|
|
||||||
def test_tuple_id(self):
|
def test_tuple_id(self):
|
||||||
|
@ -733,9 +708,7 @@ while True:
|
||||||
break
|
break
|
||||||
|
|
||||||
i"""
|
i"""
|
||||||
result, is_final_answer = evaluate_python_code(
|
result, is_final_answer = evaluate_python_code(code, {"print": print, "round": round}, state={})
|
||||||
code, {"print": print, "round": round}, state={}
|
|
||||||
)
|
|
||||||
assert result == 3
|
assert result == 3
|
||||||
assert not is_final_answer
|
assert not is_final_answer
|
||||||
|
|
||||||
|
@ -781,9 +754,7 @@ out = [i for sublist in all_res for i in sublist]
|
||||||
out[:10]
|
out[:10]
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result, is_final_answer = evaluate_python_code(
|
result, is_final_answer = evaluate_python_code(code, {"print": print, "range": range}, state=state)
|
||||||
code, {"print": print, "range": range}, state=state
|
|
||||||
)
|
|
||||||
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
||||||
|
|
||||||
def test_pandas(self):
|
def test_pandas(self):
|
||||||
|
@ -798,9 +769,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
|
||||||
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
|
||||||
code, {}, state=state, authorized_imports=["pandas"]
|
|
||||||
)
|
|
||||||
assert np.array_equal(result, [-1, 5])
|
assert np.array_equal(result, [-1, 5])
|
||||||
|
|
||||||
code = """
|
code = """
|
||||||
|
@ -811,9 +780,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
|
||||||
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
||||||
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
||||||
"""
|
"""
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
||||||
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
|
||||||
)
|
|
||||||
assert np.array_equal(result.values[0], [104, 1])
|
assert np.array_equal(result.values[0], [104, 1])
|
||||||
|
|
||||||
# Test groupby
|
# Test groupby
|
||||||
|
@ -825,9 +792,7 @@ data = pd.DataFrame.from_dict([
|
||||||
])
|
])
|
||||||
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
|
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
|
||||||
"""
|
"""
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
|
||||||
code, {}, state={}, authorized_imports=["pandas"]
|
|
||||||
)
|
|
||||||
assert result.values[1] == 0.5
|
assert result.values[1] == 0.5
|
||||||
|
|
||||||
# Test loc and iloc
|
# Test loc and iloc
|
||||||
|
@ -839,11 +804,9 @@ data = pd.DataFrame.from_dict([
|
||||||
])
|
])
|
||||||
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
|
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
|
||||||
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
|
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
|
||||||
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
|
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
|
||||||
"""
|
"""
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
|
||||||
code, {}, state={}, authorized_imports=["pandas"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_starred(self):
|
def test_starred(self):
|
||||||
code = """
|
code = """
|
||||||
|
@ -864,9 +827,7 @@ coords_barcelona = (41.3869, 2.1660)
|
||||||
|
|
||||||
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
||||||
"""
|
"""
|
||||||
result, _ = evaluate_python_code(
|
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
|
||||||
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
|
|
||||||
)
|
|
||||||
assert round(result, 1) == 622395.4
|
assert round(result, 1) == 622395.4
|
||||||
|
|
||||||
def test_for(self):
|
def test_for(self):
|
||||||
|
|
|
@ -16,7 +16,7 @@ import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import mcp
|
import mcp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -32,6 +32,7 @@ from smolagents.types import (
|
||||||
AgentText,
|
AgentText,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -48,9 +49,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||||
if input_type == "string":
|
if input_type == "string":
|
||||||
inputs[input_name] = "Text input"
|
inputs[input_name] = "Text input"
|
||||||
elif input_type == "image":
|
elif input_type == "image":
|
||||||
inputs[input_name] = Image.open(
|
inputs[input_name] = Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))
|
||||||
Path(get_tests_dir("fixtures")) / "000000039769.png"
|
|
||||||
).resize((512, 512))
|
|
||||||
elif input_type == "audio":
|
elif input_type == "audio":
|
||||||
inputs[input_name] = np.ones(3000)
|
inputs[input_name] = np.ones(3000)
|
||||||
else:
|
else:
|
||||||
|
@ -224,9 +223,7 @@ class ToolTests(unittest.TestCase):
|
||||||
class FailTool(Tool):
|
class FailTool(Tool):
|
||||||
name = "specific"
|
name = "specific"
|
||||||
description = "test description"
|
description = "test description"
|
||||||
inputs = {
|
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
||||||
"string_input": {"type": "string", "description": "input description"}
|
|
||||||
}
|
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def __init__(self, url):
|
def __init__(self, url):
|
||||||
|
@ -248,9 +245,7 @@ class ToolTests(unittest.TestCase):
|
||||||
class FailTool(Tool):
|
class FailTool(Tool):
|
||||||
name = "specific"
|
name = "specific"
|
||||||
description = "test description"
|
description = "test description"
|
||||||
inputs = {
|
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
||||||
"string_input": {"type": "string", "description": "input description"}
|
|
||||||
}
|
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def useless_method(self):
|
def useless_method(self):
|
||||||
|
@ -269,9 +264,7 @@ class ToolTests(unittest.TestCase):
|
||||||
class SuccessTool(Tool):
|
class SuccessTool(Tool):
|
||||||
name = "specific"
|
name = "specific"
|
||||||
description = "test description"
|
description = "test description"
|
||||||
inputs = {
|
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
||||||
"string_input": {"type": "string", "description": "input description"}
|
|
||||||
}
|
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def useless_method(self):
|
def useless_method(self):
|
||||||
|
@ -300,9 +293,7 @@ class ToolTests(unittest.TestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def forward(
|
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
|
||||||
self, location: str, celsius: Optional[bool] = False
|
|
||||||
) -> str:
|
|
||||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
GetWeatherTool()
|
GetWeatherTool()
|
||||||
|
@ -340,9 +331,7 @@ class ToolTests(unittest.TestCase):
|
||||||
}
|
}
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def forward(
|
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
|
||||||
self, location: str, celsius: Optional[bool] = False
|
|
||||||
) -> str:
|
|
||||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
GetWeatherTool()
|
GetWeatherTool()
|
||||||
|
@ -410,9 +399,7 @@ def mock_smolagents_adapter():
|
||||||
|
|
||||||
|
|
||||||
class TestToolCollection:
|
class TestToolCollection:
|
||||||
def test_from_mcp(
|
def test_from_mcp(self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter):
|
||||||
self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter
|
|
||||||
):
|
|
||||||
with ToolCollection.from_mcp(mock_server_parameters) as tool_collection:
|
with ToolCollection.from_mcp(mock_server_parameters) as tool_collection:
|
||||||
assert isinstance(tool_collection, ToolCollection)
|
assert isinstance(tool_collection, ToolCollection)
|
||||||
assert len(tool_collection.tools) == 2
|
assert len(tool_collection.tools) == 2
|
||||||
|
@ -440,9 +427,5 @@ class TestToolCollection:
|
||||||
|
|
||||||
with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
|
with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
|
||||||
assert len(tool_collection.tools) == 1, "Expected 1 tool"
|
assert len(tool_collection.tools) == 1, "Expected 1 tool"
|
||||||
assert tool_collection.tools[0].name == "echo_tool", (
|
assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
|
||||||
"Expected tool name to be 'echo_tool'"
|
assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
|
||||||
)
|
|
||||||
assert tool_collection.tools[0](text="Hello") == "Hello", (
|
|
||||||
"Expected tool to echo the input text"
|
|
||||||
)
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents.utils import parse_code_blobs
|
from smolagents.utils import parse_code_blobs
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
ROOT = Path(__file__).parent.parent
|
ROOT = Path(__file__).parent.parent
|
||||||
|
|
||||||
TESTS_FOLDER = ROOT / "tests"
|
TESTS_FOLDER = ROOT / "tests"
|
||||||
|
@ -37,11 +38,7 @@ def check_tests_in_ci():
|
||||||
if path.name.startswith("test_")
|
if path.name.startswith("test_")
|
||||||
]
|
]
|
||||||
ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
|
ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
|
||||||
missing_test_files = [
|
missing_test_files = [test_file for test_file in test_files if test_file not in ci_workflow_file_content]
|
||||||
test_file
|
|
||||||
for test_file in test_files
|
|
||||||
if test_file not in ci_workflow_file_content
|
|
||||||
]
|
|
||||||
if missing_test_files:
|
if missing_test_files:
|
||||||
print(
|
print(
|
||||||
"❌ Some test files seem to be ignored in the CI:\n"
|
"❌ Some test files seem to be ignored in the CI:\n"
|
||||||
|
|
Loading…
Reference in New Issue