Attach metadata to steps

You may want to log metadata and attach it to a specific step. The log_step_metadata method does this: it takes a dictionary of key-value pairs and attaches it to a step as metadata. Each value can be any JSON-serializable value, or one of ZenML's custom metadata types such as Uri, Path, DType, and StorageSize.
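
Because values are expected to be JSON-serializable, it can help to check a dictionary before logging it. The following is a minimal sketch rather than part of ZenML: find_unserializable_keys is a hypothetical helper, and the standard library's json.dumps is used only as a stand-in for whatever validation ZenML performs internally.

```python
import json
from typing import Any, Dict, List


def find_unserializable_keys(metadata: Dict[str, Any]) -> List[str]:
    """Return the top-level keys whose values json.dumps cannot encode.

    Hypothetical helper for illustration; not part of the ZenML API.
    """
    bad_keys = []
    for key, value in metadata.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            # TypeError: unsupported type; ValueError: e.g. circular reference
            bad_keys.append(key)
    return bad_keys


metadata = {
    "evaluation_metrics": {"accuracy": 0.93, "precision": 0.91},
    "labels": {"cat", "dog"},  # a set is not JSON-serializable
}
print(find_unserializable_keys(metadata))  # prints ['labels']
```

Converting the offending values (here, turning the set into a sorted list) before logging would make the dictionary pass this check.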

You can call this method from within a step or from outside one. When called from within a step, it attaches the metadata to the step and run that are currently being executed.

from zenml import step, log_step_metadata, ArtifactConfig
from typing import Annotated
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.base import ClassifierMixin

@step
def train_model(dataset: pd.DataFrame) -> Annotated[ClassifierMixin, ArtifactConfig(name="sklearn_classifier", is_model_artifact=True)]:
    """Train a model"""
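    # Fit the model and compute metrics
    # Assumption: the label column is named "target"; adjust for your dataset
    features, labels = dataset.drop(columns=["target"]), dataset["target"]
    classifier = RandomForestClassifier().fit(features, labels)
    accuracy, precision, recall = ...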

    # Log metadata at the step level
    # This associates the metadata with the ZenML step run
    log_step_metadata(
        metadata={
            "evaluation_metrics": {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall
            }
        },
    )
    return classifier
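
The step above leaves the metric computation elided. As one hedged illustration of how the nested evaluation_metrics dictionary could be built, the sketch below computes accuracy, precision, and recall for binary labels in plain Python. The binary_metrics helper and the sample labels are hypothetical and are not taken from the original example.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Compute accuracy, precision, and recall for 0/1 labels.

    Hypothetical helper for illustration; a metrics library could be used instead.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives
    correct = sum(1 for t, p in pairs if t == p)
    return {
        "accuracy": correct / len(pairs) if pairs else 0.0,
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }


# Made-up labels, used only to show the shape of the resulting dictionary
metrics = binary_metrics([1, 0, 1, 1], [1, 0, 0, 1])
metadata = {"evaluation_metrics": metrics}
print(metadata)
```

A dictionary shaped like this one is what the step passes as the metadata argument.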

When called from outside a step, the method can attach metadata to a specific step run of any pipeline. This is useful when you want to attach metadata after the step has already run.

from zenml import log_step_metadata
# run some step

# subsequently log the metadata for the step
log_step_metadata(
    metadata={
        "some_metadata": {"a_number": 3}
    },
    pipeline_name_id_or_prefix="my_pipeline",
    step_name="my_step",
    run_id="my_step_run_id"
)

Fetching logged metadata

Once metadata has been logged for a step, you can fetch it with the ZenML Client:

from zenml.client import Client

client = Client()
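# "my_pipeline_run" is a placeholder: pass the name, ID, or prefix of your pipeline run
step = client.get_pipeline_run("my_pipeline_run").steps["step_name"]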

print(step.run_metadata["metadata_key"].value)
