Embeddings generation

Generate embeddings to improve retrieval performance.

Generating Embeddings for Retrieval

In this section, we'll explore how to generate embeddings for your data to improve retrieval performance in your RAG pipeline. Embeddings are a crucial part of the retrieval mechanism in RAG: they represent each piece of data as a point in a high-dimensional space where semantically similar items sit close together, so the pipeline can surface more accurate and relevant responses to user queries.

Embeddings are vector representations of data that capture its semantic meaning and context in a high-dimensional space. They are generated by machine learning models, such as word-embedding or sentence-embedding models, that learn to encode the data in a way that preserves its underlying structure and relationships. Embeddings are commonly used in natural language processing (NLP) tasks, such as text classification, sentiment analysis, and information retrieval, to represent textual data in a format suitable for computational processing.

The whole purpose of the embeddings is to allow us to quickly find the small chunks that are most relevant to our input query at inference time. An even simpler approach would be to search for keywords from the query and hope that they also appear in the chunks, but this is not very robust and may not work well for more complex queries or longer documents. By using embeddings, we can capture the semantic meaning and context of the data and retrieve the most relevant chunks based on their similarity to the query, even when the query and the chunk share few or no words.
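As a minimal illustration of that idea (separate from the pipeline code below), here is a sketch that ranks a handful of made-up chunks against a query by cosine similarity. The chunk texts and the query are placeholders, but the model is the same one we use in the pipeline step:

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")

# Placeholder chunks -- in the real pipeline these come from the splitting step
chunks = [
    "ZenML steps are the building blocks of pipelines.",
    "Artifacts produced by steps are stored and versioned automatically.",
    "Stacks define the infrastructure your pipelines run on.",
]
chunk_embeddings = model.encode(chunks)  # shape: (3, 384)

# Note: little word overlap with the second chunk, yet it should rank highest
query_embedding = model.encode("Where do my pipeline outputs get saved?")

# Cosine similarity between the query and every chunk
scores = chunk_embeddings @ query_embedding / (
    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding)
)
best = int(np.argmax(scores))
print(chunks[best])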

We're using the sentence-transformers library to generate embeddings for our data. This library provides pre-trained models for generating sentence embeddings that capture the semantic meaning of the text. It's an open-source library that is easy to use and provides high-quality embeddings for a wide range of NLP tasks.

from typing import Annotated, List
import numpy as np
from sentence_transformers import SentenceTransformer
from zenml import ArtifactConfig, log_artifact_metadata, step
from zenml.logger import get_logger

logger = get_logger(__name__)

@step
def generate_embeddings(
    split_documents: List[str],
) -> Annotated[np.ndarray, ArtifactConfig(name="embeddings")]:
    """Generates embeddings for a list of split documents using a SentenceTransformer model."""
    try:
        model = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")

        log_artifact_metadata(
            artifact_name="embeddings",
            metadata={
                "embedding_type": "sentence-transformers/all-MiniLM-L12-v2",
                "embedding_dimensionality": 384,
            },
        )
        return model.encode(split_documents)
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {e}")
        raise

There are smaller embedding models if we cared a lot about speed, and larger ones (with more dimensions) if we wanted to boost our ability to retrieve more relevant chunks. The model we're using here is on the smaller side, but it should work well for our use case. The embeddings generated by this model have a dimensionality of 384, which means each embedding is represented as a 384-dimensional vector.
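If you want to experiment with that trade-off, swapping models is a one-line change. The checkpoint names below are other widely used sentence-transformers models (worth verifying availability before relying on them); just remember to update the embedding_dimensionality metadata logged above to match:

from sentence_transformers import SentenceTransformer

# Smaller, faster variant with fewer layers (still 384 dimensions)
fast_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print(fast_model.get_sentence_embedding_dimension())  # 384

# Larger model with higher-dimensional embeddings, often better retrieval
quality_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
print(quality_model.get_sentence_embedding_dimension())  # 768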

We can use the dimensionality reduction functionality in umap and scikit-learn to project the 384 dimensions of our embeddings into two-dimensional space. This allows us to visualize the embeddings and see how similar chunks cluster together based on their semantic meaning and context. We can also use this visualization to identify patterns and relationships in the data that can help us improve the retrieval performance of our RAG pipeline. It's worth trying both UMAP and t-SNE, since their reduction algorithms produce somewhat different representations of the data, as you'll see.

import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import umap
from zenml.client import Client

artifact = Client().get_artifact_version('EMBEDDINGS_ARTIFACT_UUID_GOES_HERE')
embeddings = artifact.load()

# Dimensionality reduction using t-SNE
def tsne_visualization(embeddings):
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)
    return embeddings_2d

# Dimensionality reduction using UMAP
def umap_visualization(embeddings):
    umap_2d = umap.UMAP(n_components=2, random_state=42)
    embeddings_2d = umap_2d.fit_transform(embeddings)
    return embeddings_2d

# Visualize embeddings using t-SNE
tsne_embeddings = tsne_visualization(embeddings)
plt.figure(figsize=(8, 8))
plt.scatter(tsne_embeddings[:, 0], tsne_embeddings[:, 1])
plt.title("t-SNE Visualization")
plt.show()

# Visualize embeddings using UMAP
umap_embeddings = umap_visualization(embeddings)
plt.figure(figsize=(8, 8))
plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1])
plt.title("UMAP Visualization")
plt.show()

At a later stage we could attach labels to the chunks (for example, the type or section of documentation they come from) to check that those semantic categories are actually reflected in our embedding space; a rough sketch of that follows. For now, this was just to show that you can visualize the embeddings and see how similar chunks cluster together based on their semantic meaning and context.
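If each chunk carried a section label, you could color the UMAP scatter plot by label, reusing umap_embeddings from the snippet above. The labels here are entirely made up and would come from your chunking step in practice:

import matplotlib.pyplot as plt
import numpy as np

# Hypothetical labels -- replace with real per-chunk metadata
n = len(umap_embeddings)
section_labels = np.array(
    ["getting-started", "user-guide", "api-reference"] * (n // 3 + 1)
)[:n]

plt.figure(figsize=(8, 8))
for label in np.unique(section_labels):
    mask = section_labels == label
    plt.scatter(umap_embeddings[mask, 0], umap_embeddings[mask, 1], s=10, label=label)
plt.legend()
plt.title("UMAP Visualization by Section")
plt.show()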

So this step iterates through all the chunks and generates embeddings representing each piece of text. These embeddings are then stored as an artifact in the ZenML artifact store as a NumPy array. We separate this generation from the point where we upload those embeddings to the vector database to keep the pipeline modular and flexible; in the future we might want to use a different vector database, so we can just swap out the upload step without having to regenerate the embeddings.
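In pipeline terms, that separation might look something like the following sketch. Here index_generator and rag_indexing_pipeline are hypothetical names for the upload step and the pipeline (not the final implementation), and generate_embeddings is the step defined above:

from typing import List

import numpy as np
from zenml import pipeline, step

@step
def index_generator(embeddings: np.ndarray, documents: List[str]) -> None:
    """Hypothetical upload step: push embeddings into a vector database.

    Switching vector databases only means replacing this step; the
    embeddings artifact produced upstream is reused as-is.
    """
    print(f"Would upload {embeddings.shape[0]} embeddings")

@pipeline
def rag_indexing_pipeline(documents: List[str]) -> None:
    embeddings = generate_embeddings(documents)
    index_generator(embeddings, documents)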

In the next section, we'll explore how to store these embeddings in a vector database to enable fast and efficient retrieval of relevant chunks at inference time.

Code Example

To explore the full code, visit the Complete Guide repository. The embeddings generation step can be found here.