Implementing reranking in ZenML
Learn how to implement reranking in ZenML.
We already have a working RAG pipeline, so inserting a reranker into it is relatively straightforward. The reranker takes the documents returned by the initial retrieval step and reorders them according to their relevance to the query that was used to retrieve them.
How and where to add reranking
We'll use the `rerankers` package to handle the reranking process in our RAG inference pipeline. It's a lightweight dependency that is relatively low-cost (in terms of technical debt and complexity) to add to our pipeline. It offers a single interface to most of the model types commonly used for reranking, which means we don't have to worry about the specifics of each model.
This package provides a `Reranker` abstract class that you can use to define your own reranker, as well as ready-made implementations you can use to add reranking to your pipeline. The reranker takes the query and a list of retrieved documents as input and outputs a reordered list of those documents based on the reranking scores. Here's a toy example of how that works.
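The snippet below is a minimal sketch of such an example; the sample texts and query are purely illustrative, and the default cross-encoder model is downloaded the first time it is used:

```python
from rerankers import Reranker

# Use the package's default cross-encoder reranker model
ranker = Reranker("cross-encoder")

# A small mixed bag of documents: some about sport, some about animals
texts = [
    "I love dogs",
    "I like to play soccer",
    "Cats are very independent animals",
    "Basketball is a fun sport to watch",
    "I like to play football",
]

# Rerank the documents against a sports-related query
results = ranker.rank(query="What sports do you enjoy playing?", docs=texts)

# Print the documents in their new order, most relevant first
for result in results.results:
    print(result.document.text)
```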
And results will look something like this:
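The ordering below is only indicative; exact scores and ordering depend on the reranker model used.

```
I like to play soccer
I like to play football
Basketball is a fun sport to watch
I love dogs
Cats are very independent animals
```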
We can see that the reranker has reordered the documents based on the reranking scores, with the most relevant documents appearing at the top of the list: the texts about sport come first, while the less relevant ones about animals end up at the bottom.
We specified that we want a `cross-encoder` reranker, but you can also use other reranker models from the Hugging Face Hub, use API-driven reranker models (from Jina or Cohere, for example), or even define your own reranker model. Read the `rerankers` documentation to see how to use these different configurations.
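For instance, switching models is mostly a matter of changing the arguments to the `Reranker` constructor. The model names and parameters below are illustrative, so check the `rerankers` documentation for what is currently supported:

```python
from rerankers import Reranker

# A specific cross-encoder model from the Hugging Face Hub
ranker = Reranker("mixedbread-ai/mxbai-rerank-base-v1", model_type="cross-encoder")

# An API-driven reranker (requires an API key from the provider)
ranker = Reranker("cohere", api_key="YOUR_API_KEY")
```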
In our case, we can simply add a helper function that can optionally be invoked when we want to use the reranker.
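The sketch below shows roughly what such a helper might look like; the full implementation lives in the project repository, but the shape is the same: build a reranker, rank the document texts, then map the results back to their URLs.

```python
from typing import List, Tuple

from rerankers import Reranker


def rerank_documents(
    query: str,
    documents: List[Tuple],
    reranker_model: str = "flashrank",
) -> List[Tuple[str, str]]:
    """Rerank the given (content, url) documents against the query.

    Returns a list of (content, url) tuples, reordered by relevance.
    """
    ranker = Reranker(reranker_model)

    # Rank only the document texts
    docs_texts = [content for content, _url in documents]
    results = ranker.rank(query=query, docs=docs_texts)

    # Map the reranked results back to their original (content, url) tuples;
    # when plain strings are passed, doc_id defaults to the list index
    reranked: List[Tuple[str, str]] = []
    for result in results.results:
        original_index = result.document.doc_id
        reranked.append(documents[original_index])
    return reranked
```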
This function takes a query and a list of documents (each document is a tuple of content and URL) and reranks the documents based on the query. It returns a list of tuples, where each tuple contains the reranked document text and the URL of the original document. By default we use the `flashrank` model from the `rerankers` package, as it proved a good fit for our use case during development.
This function then gets used in our tests.
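A sketch of that usage might look like the following; the helpers `get_embeddings`, `get_db_conn`, and `get_topn_similar_docs` are assumed names standing in for the retrieval utilities used elsewhere in the guide, and the full version is in `eval_retrieval.py`:

```python
from typing import List

# These helpers are assumed to exist in your project (embedding, DB access,
# similarity search); the module path and names here are placeholders.
from utils import get_db_conn, get_embeddings, get_topn_similar_docs
from rerank import rerank_documents  # the helper sketched above


def query_similar_docs(
    question: str,
    use_reranking: bool = False,
    returned_sample_size: int = 5,
) -> List[str]:
    """Return the URLs of the documents most similar to the question."""
    # Embed the incoming question and connect to the PostgreSQL database
    embedded_question = get_embeddings(question)
    db_conn = get_db_conn()

    # Fetch more candidates when reranking so the reranker has room to work
    num_docs = 20 if use_reranking else returned_sample_size
    top_similar_docs = get_topn_similar_docs(
        embedded_question, db_conn, n=num_docs, include_metadata=True
    )  # list of (content, url) tuples

    if use_reranking:
        # Rerank the candidates, then keep only the top five
        reranked = rerank_documents(question, top_similar_docs)[:returned_sample_size]
        urls = [url for _content, url in reranked]
    else:
        urls = [url for _content, url in top_similar_docs]

    return urls
```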
We get the embeddings for the question passed into the function and connect to our PostgreSQL database. If reranking is enabled, we retrieve the top 20 documents similar to our query, rerank them using the `rerank_documents` helper function, extract the URLs from the reranked documents, and return them. Note that we always return only five URLs: when reranking, we fetch a larger number of documents and URLs from the database to give the reranker more candidates to work with, but in the end we keep only the top five reranked documents.
Now that we've added reranking to our pipeline, we can evaluate the performance of our reranker and see how it affects the quality of the retrieved documents.
Code Example
To explore the full code, visit the Complete Guide repository and, for this section, particularly the `eval_retrieval.py` file.