# Finetuning embeddings with Sentence Transformers

We now have a dataset that we can use to finetune our embeddings. Since our previous pipeline pushed the data to the Hugging Face Hub, you can [inspect the positive and negative examples](https://huggingface.co/datasets/zenml/rag_qa_embedding_questions_0_60_0_distilabel) on the dataset page there.

![Synthetic data generated with distilabel for embeddings finetuning](/files/q7gAZ4I584VcdJoDStuc)

Our pipeline for finetuning the embeddings is relatively simple. We'll do the following:

* load our data either from Hugging Face or [from Argilla via the ZenML annotation integration](https://docs.zenml.io/stacks/annotators/argilla)
* finetune our model using the [Sentence Transformers](https://www.sbert.net/) library
* evaluate the base and finetuned embeddings
* visualize the results of the evaluation

![Embeddings finetuning pipeline with Sentence Transformers and ZenML](/files/Dtxr8zBvsV3FE2dzyrNv)

### Loading data

By default the pipeline will load the data from our Hugging Face dataset. If you've annotated your data in Argilla, you can load the data from there instead. You'll just need to pass an `--argilla` flag to the Python invocation when you're running the pipeline like so:

```bash
python run.py --embeddings --argilla
```

This assumes that you've set up an Argilla annotator in your stack. The code checks for the annotator and downloads the data that was annotated in Argilla. Please see our [guide to using the Argilla integration with ZenML](https://docs.zenml.io/stacks/annotators/argilla) for more details.
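As a rough illustration of how a flag like `--argilla` can switch the data source, here is a minimal sketch using `argparse`. The function names `load_from_hugging_face`, `load_from_argilla`, and `select_data_source` are hypothetical placeholders, not the project's actual code, and the real `run.py` may parse its options differently.

```python
import argparse
from typing import List, Optional


def load_from_hugging_face() -> str:
    # Hypothetical placeholder for the default path (Hugging Face dataset).
    return "hugging_face"


def load_from_argilla() -> str:
    # Hypothetical placeholder for the Argilla path (annotated data).
    return "argilla"


def select_data_source(argv: List[str]) -> Optional[str]:
    """Pick a data source based on the command-line flags."""
    parser = argparse.ArgumentParser(description="Embeddings finetuning (sketch)")
    parser.add_argument("--embeddings", action="store_true",
                        help="Run the embeddings finetuning pipeline")
    parser.add_argument("--argilla", action="store_true",
                        help="Load annotated data from Argilla instead of Hugging Face")
    args = parser.parse_args(argv)

    if not args.embeddings:
        return None  # the embeddings pipeline was not requested
    return load_from_argilla() if args.argilla else load_from_hugging_face()


print(select_data_source(["--embeddings", "--argilla"]))
```

Without the `--argilla` flag, the sketch falls back to the Hugging Face path, mirroring the default behaviour described above.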

### Finetuning with Sentence Transformers

The `finetune` step in the pipeline is responsible for finetuning the embeddings model using the Sentence Transformers library. Let's break down the key aspects of this step:

1. **Model Loading**: The code loads the base model (`EMBEDDINGS_MODEL_ID_BASELINE`) using the Sentence Transformers library. It uses PyTorch's SDPA (scaled dot-product attention) implementation, which allows efficient training with Flash Attention 2 kernels where they are available.
2. **Loss Function**: The finetuning process uses `MatryoshkaLoss`, a loss function provided by Sentence Transformers that wraps another loss, here `MultipleNegativesRankingLoss`. The Matryoshka approach trains the model at several embedding dimensions simultaneously, so that the leading dimensions of each embedding remain useful on their own. This improves the model's performance when embeddings are truncated to smaller sizes.
3. **Dataset Preparation**: The training dataset is loaded from the provided `dataset` parameter. The code saves the training data to a temporary JSON file and then loads it using the Hugging Face `load_dataset` function.
4. **Evaluator**: An evaluator is created using the `get_evaluator` function. The evaluator is responsible for assessing the model's performance during training.
5. **Training Arguments**: The code sets up the training arguments using the `SentenceTransformerTrainingArguments` class. It specifies various hyperparameters such as the number of epochs, batch size, learning rate, optimizer, precision (TF32 and BF16), and evaluation strategy.
6. **Trainer**: The `SentenceTransformerTrainer` is initialized with the model, training arguments, training dataset, loss function, and evaluator. Calling `trainer.train()` starts the finetuning process, which runs for the specified number of epochs using the provided hyperparameters.
7. **Model Saving**: After training, the finetuned model is pushed to the Hugging Face Hub using the `trainer.model.push_to_hub()` method. The model is saved with the specified ID (`EMBEDDINGS_MODEL_ID_FINE_TUNED`).
8. **Metadata Logging**: The code logs relevant metadata about the training process, including the training parameters, hardware information, and accelerator details.
9. **Model Rehydration**: To handle materialization errors, the code saves the trained model to a temporary file, loads it back into a new `SentenceTransformer` instance, and returns the rehydrated model.
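
To make the loss function in step 2 more concrete, the following pure-Python sketch shows the idea behind combining a Matryoshka wrapper with an in-batch-negatives ranking loss. It is a simplified illustration under stated assumptions, not the Sentence Transformers implementation: the toy vectors, the `scale` value of 20, and the equal weighting of dimensions are all chosen for the example.

```python
import math
from typing import List

Vector = List[float]


def truncate(vec: Vector, dim: int) -> Vector:
    """Keep the first `dim` values and rescale to unit length."""
    head = vec[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head] if norm > 0 else head


def ranking_loss(anchors: List[Vector], positives: List[Vector], scale: float = 20.0) -> float:
    """Cross-entropy where positive i is the target for anchor i and
    every other positive in the batch acts as a negative."""
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [scale * sum(a * p for a, p in zip(anchor, pos)) for pos in positives]
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        total += log_sum_exp - logits[i]
    return total / len(anchors)


def matryoshka_loss(anchors: List[Vector], positives: List[Vector], dims: List[int]) -> float:
    """Sum the ranking loss over several truncated embedding sizes."""
    return sum(
        ranking_loss([truncate(a, d) for a in anchors],
                     [truncate(p, d) for p in positives])
        for d in dims
    )


# Toy 4-dimensional "embeddings" for two question/document pairs.
anchors = [[1.0, 0.0, 0.2, 0.1], [0.0, 1.0, 0.1, 0.3]]
matched = [[0.9, 0.1, 0.2, 0.0], [0.1, 0.9, 0.0, 0.3]]
mismatched = list(reversed(matched))

print(matryoshka_loss(anchors, matched, dims=[4, 2]))
print(matryoshka_loss(anchors, mismatched, dims=[4, 2]))
```

The loss is low when each question is closest to its own document at every truncated size, and high when the pairs are swapped, which is the signal that pushes useful information into the leading dimensions.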

(*Thanks and credit to Phil Schmid for [his tutorial on finetuning embeddings](https://www.philschmid.de/fine-tune-embedding-model-for-rag) with Sentence Transformers and a Matryoshka loss function. This project uses many ideas and some code from his implementation.*)

### Finetuning in code

Here's a simplified code snippet highlighting the key parts of the finetuning process:

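```python
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

# Load the base model
model = SentenceTransformer(EMBEDDINGS_MODEL_ID_BASELINE)
# Define the loss function; the dimensions listed here are illustrative values
inner_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(model, inner_loss, matryoshka_dims=[384, 256, 128, 64])
# Prepare the training dataset (split="train" returns a Dataset, not a DatasetDict)
train_dataset = load_dataset("json", data_files=train_dataset_path, split="train")
# Set up the training arguments
args = SentenceTransformerTrainingArguments(...)
# Create the trainer; keyword arguments avoid passing the loss in the wrong position
trainer = SentenceTransformerTrainer(
    model=model, args=args, train_dataset=train_dataset, loss=train_loss
)
# Start training
trainer.train()
# Save the finetuned model
trainer.model.push_to_hub(EMBEDDINGS_MODEL_ID_FINE_TUNED)
```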

The Sentence Transformers library handles the training loop for us, while the Matryoshka loss trains embeddings that remain useful when truncated to smaller dimensions.
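
One practical consequence of Matryoshka training is that embeddings can be shortened at query time. The sketch below illustrates this in pure Python on made-up vectors: it truncates each embedding to its first `dim` values, re-normalizes, and ranks documents by cosine similarity. The helper names and the toy data are assumptions for the example and are not part of the pipeline.

```python
import math
from typing import List

Vector = List[float]


def truncate_and_normalize(embedding: Vector, dim: int) -> Vector:
    """Keep the first `dim` values and rescale them to unit length."""
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head] if norm > 0 else head


def rank_documents(query: Vector, documents: List[Vector], dim: int) -> List[int]:
    """Return document indices sorted from most to least similar to the query."""
    q = truncate_and_normalize(query, dim)
    scores = []
    for doc in documents:
        d = truncate_and_normalize(doc, dim)
        scores.append(sum(a * b for a, b in zip(q, d)))  # cosine similarity
    return sorted(range(len(documents)), key=lambda i: scores[i], reverse=True)


# Made-up 4-dimensional embeddings: one query and three documents.
query = [0.9, 0.1, 0.3, 0.2]
documents = [
    [0.8, 0.2, 0.1, 0.4],
    [0.1, 0.9, 0.2, 0.1],
    [-0.7, 0.1, 0.3, 0.2],
]

for dim in (4, 2):
    print(dim, rank_documents(query, documents, dim))
```

In this toy case the ranking is the same at both sizes, which is the behaviour Matryoshka training aims for. With a real model, recent versions of Sentence Transformers also offer a `truncate_dim` argument on `SentenceTransformer` for the same purpose, so check the library documentation for the version you use.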

Our model is now finetuned. It is saved to the Hugging Face Hub for easy access and reference in subsequent steps, and it is also versioned and tracked within ZenML for full observability. At this point the pipeline will evaluate the base and finetuned embeddings and visualize the results.
