Develop a Custom Model Registry

Learn how to develop a custom model registry.

Before diving into the specifics of this component type, we recommend reading our general guide to writing custom component flavors in ZenML. It explains the component flavor concepts that this page builds on.

Base Abstraction

The BaseModelRegistry is the abstract base class you subclass to create a custom component that registers and retrieves models. Because model registries come in many shapes and forms, the base class exposes a deliberately basic and generic interface:

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Type, cast

from pydantic import BaseModel, Field, root_validator

from zenml.enums import StackComponentType
from zenml.stack import Flavor, StackComponent
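from zenml.stack.stack_component import StackComponentConfig

# The remaining imports (Enum, pydantic, Type, Flavor, StackComponentType) support
# parts of the full module that this excerpt omits: the ModelVersionStage,
# RegisteredModel and RegistryModelVersion data models used in the signatures
# below, and the BaseModelRegistryFlavor class.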


class BaseModelRegistryConfig(StackComponentConfig):
    """Base config for model registries."""


class BaseModelRegistry(StackComponent, ABC):
    """Base class for all ZenML model registries."""

    @property
    def config(self) -> BaseModelRegistryConfig:
        """Returns the config of the model registry."""
        return cast(BaseModelRegistryConfig, self._config)

    # ---------
    # Model Registration Methods
    # ---------

    @abstractmethod
    def register_model(
            self,
            name: str,
            description: Optional[str] = None,
            tags: Optional[Dict[str, str]] = None,
    ) -> RegisteredModel:
        """Registers a model in the model registry."""

    @abstractmethod
    def delete_model(
            self,
            name: str,
    ) -> None:
        """Deletes a registered model from the model registry."""

    @abstractmethod
    def update_model(
            self,
            name: str,
            description: Optional[str] = None,
            tags: Optional[Dict[str, str]] = None,
    ) -> RegisteredModel:
        """Updates a registered model in the model registry."""

    @abstractmethod
    def get_model(self, name: str) -> RegisteredModel:
        """Gets a registered model from the model registry."""

    @abstractmethod
    def list_models(
            self,
            name: Optional[str] = None,
            tags: Optional[Dict[str, str]] = None,
    ) -> List[RegisteredModel]:
        """Lists all registered models in the model registry."""

    # ---------
    # Model Version Methods
    # ---------

    @abstractmethod
    def register_model_version(
            self,
            name: str,
            model_source_uri: Optional[str] = None,
            version: Optional[str] = None,
            description: Optional[str] = None,
            tags: Optional[Dict[str, str]] = None,
            metadata: Optional[Dict[str, str]] = None,
            zenml_version: Optional[str] = None,
            zenml_run_name: Optional[str] = None,
            zenml_pipeline_name: Optional[str] = None,
            zenml_step_name: Optional[str] = None,
            **kwargs: Any,
    ) -> RegistryModelVersion:
        """Registers a model version in the model registry."""

    @abstractmethod
    def delete_model_version(
            self,
            name: str,
            version: str,
    ) -> None:
        """Deletes a model version from the model registry."""

    @abstractmethod
    def update_model_version(
            self,
            name: str,
            version: str,
            description: Optional[str] = None,
            tags: Optional[Dict[str, str]] = None,
            stage: Optional[ModelVersionStage] = None,
    ) -> RegistryModelVersion:
        """Updates a model version in the model registry."""

    @abstractmethod
    def list_model_versions(
            self,
            name: Optional[str] = None,
            model_source_uri: Optional[str] = None,
            tags: Optional[Dict[str, str]] = None,
            **kwargs: Any,
    ) -> List[RegistryModelVersion]:
        """Lists all model versions for a registered model."""

    @abstractmethod
    def get_model_version(self, name: str, version: str) -> RegistryModelVersion:
        """Gets a model version for a registered model."""

    @abstractmethod
    def load_model_version(
            self,
            name: str,
            version: str,
            **kwargs: Any,
    ) -> Any:
        """Loads a model version from the model registry."""

    @abstractmethod
    def get_model_uri_artifact_store(
            self,
            model_version: RegistryModelVersion,
    ) -> str:
        """Gets the URI artifact store for a model version."""

This is a slimmed-down version of the base implementation that highlights the abstraction layer. To see the full implementation and the complete docstrings, check the source code on GitHub.
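Every method in the interface is decorated with @abstractmethod, so Python refuses to instantiate a subclass until all of them are implemented. The stand-alone sketch below illustrates this with a hypothetical two-method ABC, TinyModelRegistry, rather than the real ZenML class, so it runs without ZenML installed; the class names and the can_instantiate helper are illustrative only.

```python
from abc import ABC, abstractmethod
from typing import Optional


class TinyModelRegistry(ABC):
    """Hypothetical, trimmed-down stand-in for BaseModelRegistry."""

    @abstractmethod
    def register_model(self, name: str, description: Optional[str] = None) -> str:
        """Registers a model and returns its name."""

    @abstractmethod
    def delete_model(self, name: str) -> None:
        """Deletes a registered model."""


class IncompleteRegistry(TinyModelRegistry):
    # Only one of the two abstract methods is implemented here.
    def register_model(self, name: str, description: Optional[str] = None) -> str:
        return name


class CompleteRegistry(IncompleteRegistry):
    # Implementing the remaining abstract method makes the class concrete.
    def delete_model(self, name: str) -> None:
        pass


def can_instantiate(cls: type) -> bool:
    """Returns True if `cls` can be instantiated without arguments."""
    try:
        cls()
    except TypeError:
        # Python raises TypeError for classes with unimplemented abstract methods.
        return False
    return True


print(can_instantiate(IncompleteRegistry))  # False
print(can_instantiate(CompleteRegistry))  # True
```

The same rule applies to a real custom model registry: leaving out any of the abstract methods listed above makes the class impossible to instantiate.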

Build your own custom model registry

To create your own custom flavor for a model registry, follow these steps:

  1. Learn more about the core concepts of the model registry here. Your custom model registry builds on top of these concepts, so it helps to be familiar with them.

  2. Create a class that inherits from BaseModelRegistry and implements the abstract methods.

  3. Create a ModelRegistryConfig class that inherits from BaseModelRegistryConfig and adds any additional configuration parameters that you need.

  4. Bring the implementation and the configuration together in a flavor class that inherits from BaseModelRegistryFlavor. Make sure to give the flavor a name through its abstract name property.
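The steps above can be sketched as follows. This is a minimal, hypothetical in-memory registry, not a production implementation: so that it runs without ZenML installed, it replaces the ZenML base classes with simplified local stand-ins (RegisteredModelSketch, ModelVersionSketch, BaseRegistryConfigSketch and BaseRegistrySketch) and implements only two abstract methods. InMemoryRegistryConfig, its max_versions parameter, InMemoryModelRegistry, InMemoryModelRegistryFlavor and the flavor name "in_memory" are all illustrative. In a real flavor you would instead subclass BaseModelRegistry, BaseModelRegistryConfig and BaseModelRegistryFlavor from ZenML and implement every abstract method shown earlier; the flavor sketch is modelled on the name, config_class and implementation_class properties that ZenML flavors expose.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, cast


# Hypothetical, simplified stand-ins for the ZenML classes (not the real API).
@dataclass
class RegisteredModelSketch:
    name: str
    description: Optional[str] = None


@dataclass
class ModelVersionSketch:
    name: str
    version: str
    model_source_uri: Optional[str] = None


@dataclass
class BaseRegistryConfigSketch:
    """Stands in for BaseModelRegistryConfig."""


class BaseRegistrySketch(ABC):
    """Stands in for BaseModelRegistry, trimmed to two abstract methods."""

    def __init__(self, config: BaseRegistryConfigSketch) -> None:
        self._config = config

    @abstractmethod
    def register_model(
        self, name: str, description: Optional[str] = None
    ) -> RegisteredModelSketch: ...

    @abstractmethod
    def register_model_version(
        self, name: str, model_source_uri: Optional[str] = None
    ) -> ModelVersionSketch: ...


# Step 3: a config class adding a hypothetical extra parameter.
@dataclass
class InMemoryRegistryConfig(BaseRegistryConfigSketch):
    max_versions: int = 10


# Step 2: the implementation, keeping all state in plain dictionaries.
class InMemoryModelRegistry(BaseRegistrySketch):
    def __init__(self, config: InMemoryRegistryConfig) -> None:
        super().__init__(config)
        self._models: Dict[str, RegisteredModelSketch] = {}
        self._versions: Dict[str, List[ModelVersionSketch]] = {}

    @property
    def config(self) -> InMemoryRegistryConfig:
        # Same narrowing pattern as BaseModelRegistry.config.
        return cast(InMemoryRegistryConfig, self._config)

    def register_model(self, name, description=None):
        if name in self._models:
            raise ValueError(f"Model {name!r} is already registered")
        self._models[name] = RegisteredModelSketch(name, description)
        self._versions[name] = []
        return self._models[name]

    def register_model_version(self, name, model_source_uri=None):
        if name not in self._models:
            raise KeyError(f"Model {name!r} is not registered")
        versions = self._versions[name]
        if len(versions) >= self.config.max_versions:
            raise ValueError(f"Model {name!r} reached max_versions")
        # Versions are numbered sequentially as strings: "1", "2", ...
        version = ModelVersionSketch(name, str(len(versions) + 1), model_source_uri)
        versions.append(version)
        return version


# Step 4: the flavor ties the name, config class and implementation together.
class InMemoryModelRegistryFlavor:
    @property
    def name(self) -> str:
        return "in_memory"

    @property
    def config_class(self) -> Type[InMemoryRegistryConfig]:
        return InMemoryRegistryConfig

    @property
    def implementation_class(self) -> Type[InMemoryModelRegistry]:
        return InMemoryModelRegistry


flavor = InMemoryModelRegistryFlavor()
registry = flavor.implementation_class(flavor.config_class(max_versions=2))
registry.register_model("churn-model", description="Predicts customer churn")
v1 = registry.register_model_version("churn-model", "s3://bucket/churn/1")
print(flavor.name, v1.version)  # in_memory 1
```

Keeping the flavor as a thin factory that only returns classes means the stack machinery can build the config first and then the implementation, without the flavor itself holding any state.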

Once you are done with the implementation, you can register it through the CLI with the following command:

zenml model-registry flavor register <MODEL-REGISTRY-FLAVOR-SOURCE-PATH>
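The placeholder in this command stands for the Python import path of your flavor class in dotted module.ClassName form, for example a hypothetical my_flavors.registry_flavor.InMemoryModelRegistryFlavor. ZenML resolves this path itself; the sketch below only illustrates how such a dotted path maps to a class object using the standard library's importlib. The helper load_class_from_source is hypothetical, and it is demonstrated on a standard-library class so it runs anywhere.

```python
import importlib


def load_class_from_source(source: str) -> type:
    """Resolves a dotted 'package.module.ClassName' path to a class object."""
    # Split off the last component: everything before it is the module path.
    module_path, _, class_name = source.rpartition(".")
    if not module_path:
        raise ValueError(f"Expected 'module.ClassName', got {source!r}")
    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)
    if not isinstance(cls, type):
        raise TypeError(f"{source!r} does not point to a class")
    return cls


ordered_dict_cls = load_class_from_source("collections.OrderedDict")
print(ordered_dict_cls.__name__)  # OrderedDict
```

In practice this means the module containing your flavor class must be importable from the environment where you run the command.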

For a full implementation example, check out the MLFlowModelRegistry.
