Make RAG example extremely fast with BM25

Aymeric 2024-12-26 16:19:31 +01:00
parent eecd728668
commit 1abaf69b67
6 changed files with 40 additions and 102 deletions

View File

@@ -78,7 +78,7 @@ The `preview` command only works with existing doc files. When you add a complet
 Accepted files are Markdown (.md).
 Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
-the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/agents/blob/main/docs/source/_toctree.yml) file.
+the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/smolagents/blob/main/docs/source/_toctree.yml) file.
 ## Renaming section headers and moving sections
@@ -108,7 +108,7 @@ For an example of a rich moved section set please see the very end of [the trans
 ## Writing Documentation - Specification
-The `huggingface/agents` documentation follows the
+The `huggingface/smolagents` documentation follows the
 [Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
 although we can write them directly in Markdown.
@@ -123,7 +123,7 @@ Make sure to put your new file under the proper section. If you have a doubt, fe
 ### Translating
-When translating, refer to the guide at [./TRANSLATING.md](https://github.com/huggingface/agents/blob/main/docs/TRANSLATING.md).
+When translating, refer to the guide at [./TRANSLATING.md](https://github.com/huggingface/smolagents/blob/main/docs/TRANSLATING.md).
 ### Writing source documentation

View File

@@ -52,14 +52,10 @@ Then prepare the knowledge base by processing the dataset and storing it into a
 We use [LangChain](https://python.langchain.com/docs/introduction/) for its excellent vector database utilities.
 ```py
-import time
 import datasets
-from tqdm import tqdm
-from transformers import AutoTokenizer
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import FAISS, DistanceStrategy
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.retrievers import BM25Retriever
 knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
 knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
@@ -69,47 +65,17 @@ source_docs = [
     for doc in knowledge_base
 ]
-embedding_model = "TaylorAI/gte-tiny"
-text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
-    AutoTokenizer.from_pretrained(embedding_model),
-    chunk_size=200,
-    chunk_overlap=20,
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=500,
+    chunk_overlap=50,
     add_start_index=True,
     strip_whitespace=True,
     separators=["\n\n", "\n", ".", " ", ""],
 )
-# Split docs and keep only unique ones
-print("Splitting documents...")
-docs_processed = []
-unique_texts = {}
-for doc in tqdm(source_docs):
-    new_docs = text_splitter.split_documents([doc])
-    for new_doc in new_docs:
-        if new_doc.page_content not in unique_texts:
-            unique_texts[new_doc.page_content] = True
-            docs_processed.append(new_doc)
-print(
-    "Embedding documents... This could take a few minutes."
-)
-t0 = time.time()
-embedding_model = HuggingFaceEmbeddings(
-    model_name=embedding_model,
-    show_progress=True
-)
-vectordb = FAISS.from_documents(
-    documents=docs_processed,
-    embedding=embedding_model,
-    distance_strategy=DistanceStrategy.COSINE,
-)
-t1 = time.time()
-print(f"VectorDB embedded in {(t1-t0):.2f} seconds")
+docs_processed = text_splitter.split_documents(source_docs)
 ```
-If you want to improve performance, head to the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select a bigger model for your embeddings: here we selected a small one for the sake of speed.
-Now the database is ready. Building the embeddings for each document snippet took a few minutes, but now they're ready to be used in a split second.
+Now the documents are ready.
 So let's build our agentic RAG system!
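The new splitter counts characters rather than tokens (chunk_size=500, chunk_overlap=50), so no tokenizer has to be downloaded or run. As a quick sanity check — a sketch, not part of the commit, assuming the snippet above has been executed:

```py
# Peek at what the character-based splitter produced.
print(f"{len(docs_processed)} chunks")
print(docs_processed[0].metadata)            # e.g. {'source': ..., 'start_index': 0}
print(docs_processed[0].page_content[:200])  # first 200 characters of the first chunk
```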
@@ -122,7 +88,7 @@ from smolagents import Tool
 class RetrieverTool(Tool):
     name = "retriever"
-    description = "Using semantic similarity, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
+    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
     inputs = {
         "query": {
             "type": "string",
@@ -131,27 +97,31 @@ class RetrieverTool(Tool):
     }
     output_type = "string"
-    def __init__(self, vectordb, **kwargs):
+    def __init__(self, docs, **kwargs):
         super().__init__(**kwargs)
-        self.vectordb = vectordb
+        self.retriever = BM25Retriever.from_documents(
+            docs, k=10
+        )
     def forward(self, query: str) -> str:
         assert isinstance(query, str), "Your search query must be a string"
-        docs = self.vectordb.similarity_search(
+        docs = self.retriever.invoke(
             query,
-            k=10,
         )
         return "\nRetrieved documents:\n" + "".join(
             [
-                f"===== Document {str(i)} =====\n" + doc.page_content
+                f"\n\n===== Document {str(i)} =====\n" + doc.page_content
                 for i, doc in enumerate(docs)
             ]
         )
-```
-Now its straightforward to create an agent that leverages this tool!
+retriever_tool = RetrieverTool(docs_processed)
+```
+We have used BM25, a classic retrieval method, because it's lightning fast to set up.
+To improve retrieval accuracy, you could replace BM25 with semantic search using vector representations for documents: to do so, head to the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select a good embedding model.
+Now it's straightforward to create an agent that leverages this `retriever_tool`!
 The agent will need these arguments upon initialization:
 - `tools`: a list of tools that the agent will be able to call.
@@ -167,7 +137,6 @@ _Note:_ The Inference API hosts models based on various criteria, and deployed m
 ```py
 from smolagents import HfApiModel, CodeAgent
-retriever_tool = RetrieverTool(vectordb)
 agent = CodeAgent(
     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
 )
@@ -178,7 +147,7 @@ Upon initializing the CodeAgent, it has been automatically given a default syste
 Then when its `.run()` method is launched, the agent takes care of calling the LLM engine, and executing the tool calls, all in a loop that ends only when tool `final_answer` is called with the final answer as its argument.
 ```py
-agent_output = agent.run("How can I push a model to the Hub?")
+agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
 print("Final output:")
 print(agent_output)
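The updated text points readers to the MTEB Leaderboard if they want to trade setup time for semantic retrieval. A minimal sketch of that swap — not part of the diff, and essentially reinstating the FAISS index this commit removes, with `TaylorAI/gte-tiny` (the small embedding model from the old version) used purely as an illustration:

```py
# Semantic-search variant: embed the same chunks into a FAISS index
# instead of a BM25 index (assumes docs_processed from the snippet above).
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="TaylorAI/gte-tiny")  # any MTEB model works here
vectordb = FAISS.from_documents(documents=docs_processed, embedding=embeddings)

# Inside RetrieverTool.forward, the lookup would then become:
#     docs = vectordb.similarity_search(query, k=10)
```

The trade-off is the one the old example paid: embedding every chunk up front takes minutes, while building the BM25 index is near-instant.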

View File

@@ -2,4 +2,4 @@
 FROM e2bdev/code-interpreter:latest
 # Install dependencies and customize sandbox
-RUN pip install git+https://github.com/huggingface/agents.git
+RUN pip install git+https://github.com/huggingface/smolagents.git

View File

@@ -1,59 +1,28 @@
 # from huggingface_hub import login
 # login()
-import time
 import datasets
-from tqdm import tqdm
-from transformers import AutoTokenizer
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import FAISS, DistanceStrategy
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.retrievers import BM25Retriever
 knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
 knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
-embedding_model = "TaylorAI/gte-tiny"
 source_docs = [
     Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
     for doc in knowledge_base
 ]
-text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
-    AutoTokenizer.from_pretrained(embedding_model),
-    chunk_size=200,
-    chunk_overlap=20,
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=500,
+    chunk_overlap=50,
     add_start_index=True,
     strip_whitespace=True,
     separators=["\n\n", "\n", ".", " ", ""],
 )
-# Split docs and keep only unique ones
-print("Splitting documents...")
-docs_processed = []
-unique_texts = {}
-for doc in tqdm(source_docs):
-    new_docs = text_splitter.split_documents([doc])
-    for new_doc in new_docs:
-        if new_doc.page_content not in unique_texts:
-            unique_texts[new_doc.page_content] = True
-            docs_processed.append(new_doc)
-print(
-    "Embedding documents... This could take a few minutes."
-)
-t0 = time.time()
-embedding_model = HuggingFaceEmbeddings(
-    model_name=embedding_model,
-    show_progress=True
-)
-vectordb = FAISS.from_documents(
-    documents=docs_processed,
-    embedding=embedding_model,
-    distance_strategy=DistanceStrategy.COSINE,
-)
-t1 = time.time()
-print(f"VectorDB embedded in {(t1-t0):.2f} seconds")
+docs_processed = text_splitter.split_documents(source_docs)
 from smolagents import Tool
@@ -68,33 +37,33 @@ class RetrieverTool(Tool):
     }
     output_type = "string"
-    def __init__(self, vectordb, **kwargs):
+    def __init__(self, docs, **kwargs):
         super().__init__(**kwargs)
-        self.vectordb = vectordb
+        self.retriever = BM25Retriever.from_documents(
+            docs, k=10
+        )
     def forward(self, query: str) -> str:
         assert isinstance(query, str), "Your search query must be a string"
-        docs = self.vectordb.similarity_search(
+        docs = self.retriever.invoke(
             query,
-            k=10,
         )
         return "\nRetrieved documents:\n" + "".join(
             [
-                f"===== Document {str(i)} =====\n" + doc.page_content
+                f"\n\n===== Document {str(i)} =====\n" + doc.page_content
                 for i, doc in enumerate(docs)
             ]
         )
 from smolagents import HfApiModel, CodeAgent
-retriever_tool = RetrieverTool(vectordb)
+retriever_tool = RetrieverTool(docs_processed)
 agent = CodeAgent(
     tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
 )
-agent_output = agent.run("For a transformers model training, which is faster, the forward or the backward pass?")
+agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
 print("Final output:")
 print(agent_output)
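Since `RetrieverTool.forward` simply returns a formatted string, the BM25 retriever can be sanity-checked on its own before handing it to the agent. A small sketch, not part of the commit, assuming the script above has run (the query is the one the old example used):

```py
# Call the tool directly, outside the agent loop, to inspect what BM25 returns.
preview = retriever_tool.forward("How can I push a model to the Hub?")
print(preview[:1000])  # first part of the "Retrieved documents" string
```

Lookups like this are typically near-instant, which is the speedup the commit is after.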

View File

@@ -910,7 +910,7 @@ class CodeAgent(MultiStepAgent):
                     align="left",
                     style="orange",
                 ),
-                Syntax(llm_output, lexer="markdown", theme="github-dark"),
+                Syntax(llm_output, lexer="markdown", theme="github-dark", word_wrap=True),
             )
         )

View File

@@ -36,7 +36,7 @@ class E2BExecutor:
         # TODO: validate installing agents package or not
         # print("Installing agents package on remote executor...")
         # self.sbx.commands.run(
-        #     "pip install git+https://github.com/huggingface/agents.git",
+        #     "pip install git+https://github.com/huggingface/smolagents.git",
         #     timeout=300
         # )
         # print("Installation of agents package finished.")