# smolagents/examples/rag.py
# Example: an agent that answers questions using retrieval over the transformers documentation.

# from huggingface_hub import login
# login()
import time
import datasets
from tqdm import tqdm
from transformers import AutoTokenizer
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_community.embeddings import HuggingFaceEmbeddings
# Hub model id: its tokenizer sizes the text chunks and its weights compute the embeddings.
embedding_model = "TaylorAI/gte-tiny"

# Load the Hugging Face documentation corpus, then keep only the pages whose
# source path lies under huggingface/transformers.
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(
    lambda row: row["source"].startswith("huggingface/transformers")
)

# Wrap every page as a LangChain Document. metadata["source"] keeps only the
# second path component of the original source (here always "transformers").
source_docs = [
    Document(
        page_content=record["text"],
        metadata={"source": record["source"].split("/")[1]},
    )
    for record in knowledge_base
]
# Chunk length is counted in tokens of the embedding model's own tokenizer, so
# chunk_size/chunk_overlap are token counts. Separators are tried from coarsest
# (paragraph) to finest (single character). add_start_index records each chunk's
# offset in its source page.
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    AutoTokenizer.from_pretrained(embedding_model),
    separators=["\n\n", "\n", ".", " ", ""],
    strip_whitespace=True,
    add_start_index=True,
    chunk_overlap=20,
    chunk_size=200,
)
# Split every page into token-bounded chunks and drop exact-duplicate chunks
# (identical page_content). The first occurrence is kept, in corpus order.
print("Splitting documents...")
docs_processed = []
unique_texts = set()  # page_content strings already kept
for doc in tqdm(source_docs):
    for new_doc in text_splitter.split_documents([doc]):
        if new_doc.page_content in unique_texts:
            continue  # duplicate chunk text; skip it
        unique_texts.add(new_doc.page_content)
        docs_processed.append(new_doc)
# Embed every unique chunk and index it in an in-memory FAISS vector store.
print("Embedding documents... This could take a few minutes.")
t0 = time.perf_counter()  # monotonic clock, appropriate for elapsed-time measurement

# Keep `embedding_model` as the model-id string; the embeddings object gets its own name.
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model,
    show_progress=True,
    # LangChain's FAISS wrapper may build a plain L2 index even for the COSINE
    # strategy. With unit-length vectors, L2 ranking equals cosine ranking, so
    # normalising makes the requested cosine behaviour hold in either case.
    encode_kwargs={"normalize_embeddings": True},
)
vectordb = FAISS.from_documents(
    documents=docs_processed,
    embedding=embeddings,
    distance_strategy=DistanceStrategy.COSINE,
)
t1 = time.perf_counter()
print(f"VectorDB embedded in {(t1-t0):.2f} seconds")
from smolagents import Tool
class RetrieverTool(Tool):
    """Agent tool that runs a semantic search over a vector store.

    The top 10 matching chunks are returned to the agent as one formatted string.
    """

    name = "retriever"
    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vectordb, **kwargs):
        """Store the vector store to search.

        Args:
            vectordb: Object exposing ``similarity_search(query, k=...)`` (here the FAISS store).
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.vectordb = vectordb

    def forward(self, query: str) -> str:
        """Return the 10 chunks most similar to ``query``.

        Args:
            query: Search text.

        Returns:
            A "Retrieved documents:" header, then each chunk prefixed by its own
            "===== Document i =====" line. Chunks are separated by blank lines so
            a header never runs into the previous chunk's text.

        Raises:
            TypeError: If ``query`` is not a string. An explicit raise is used
                because ``assert`` is stripped under ``python -O``.
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        docs = self.vectordb.similarity_search(query, k=10)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )
from smolagents import HfApiModel, CodeAgent
# Give the retriever to a code-writing agent backed by a hosted Llama model,
# then ask one question about transformers training.
retriever_tool = RetrieverTool(vectordb)
agent = CodeAgent(
    tools=[retriever_tool],
    model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),
    # smolagents limits the agent loop with `max_steps`; `max_iterations` is not
    # one of its parameters. Verbosity is set with `verbosity_level`.
    max_steps=4,
    verbosity_level=2,
)
agent_output = agent.run("For a transformers model training, which is faster, the forward or the backward pass?")
print("Final output:")
print(agent_output)