Develop a Custom Model Registry
Learn how to develop a custom model registry.
Base Abstraction
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator
from zenml.enums import StackComponentType
from zenml.stack import Flavor, StackComponent
from zenml.stack.stack_component import StackComponentConfig
class BaseModelRegistryConfig(StackComponentConfig):
    """Shared configuration type for every model registry.

    This class declares no settings of its own; everything it carries is
    inherited from ``StackComponentConfig``.
    """
class BaseModelRegistry(StackComponent, ABC):
    """Base class for all ZenML model registries.

    Concrete model registries subclass this and implement every abstract
    method below. The interface has two groups of operations:

    * managing registered models (register, delete, update, get, list);
    * managing the versions of a registered model (register, delete,
      update, list, get, load, and resolving a version's artifact-store URI).

    NOTE: ``RegisteredModel``, ``RegistryModelVersion`` and
    ``ModelVersionStage`` are not defined or imported in the code shown here.
    They are therefore written as quoted (forward-reference) annotations.
    Quoted annotations are not evaluated when the class is created, so the
    class stays definable even if those types are declared later in the
    module.
    """

    @property
    def config(self) -> BaseModelRegistryConfig:
        """Returns the config of the model registry.

        Returns:
            The component's ``_config``, narrowed to
            ``BaseModelRegistryConfig`` via ``typing.cast``. The cast is a
            static-typing hint only; no runtime check is performed.
        """
        return cast(BaseModelRegistryConfig, self._config)

    # ---------
    # Model Registration Methods
    # ---------

    @abstractmethod
    def register_model(
        self,
        name: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> "RegisteredModel":
        """Registers a model in the model registry.

        Args:
            name: The name to register the model under.
            description: Optional description of the model.
            tags: Optional string-to-string tags to attach to the model.

        Returns:
            The registered model.
        """

    @abstractmethod
    def delete_model(
        self,
        name: str,
    ) -> None:
        """Deletes a registered model from the model registry.

        Args:
            name: The name of the registered model to delete.
        """

    @abstractmethod
    def update_model(
        self,
        name: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> "RegisteredModel":
        """Updates a registered model in the model registry.

        Args:
            name: The name of the registered model to update.
            description: Optional new description for the model.
            tags: Optional string-to-string tags for the model.

        Returns:
            The updated registered model.
        """

    @abstractmethod
    def get_model(self, name: str) -> "RegisteredModel":
        """Gets a registered model from the model registry.

        Args:
            name: The name of the registered model to fetch.

        Returns:
            The registered model.
        """

    @abstractmethod
    def list_models(
        self,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> List["RegisteredModel"]:
        """Lists all registered models in the model registry.

        Args:
            name: Optional model name used to narrow the listing.
            tags: Optional string-to-string tags used to narrow the listing.

        Returns:
            The registered models.
        """

    # ---------
    # Model Version Methods
    # ---------

    @abstractmethod
    def register_model_version(
        self,
        name: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        model_source_uri: Optional[str] = None,
        version: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        zenml_version: Optional[str] = None,
        zenml_run_name: Optional[str] = None,
        zenml_pipeline_name: Optional[str] = None,
        zenml_step_name: Optional[str] = None,
        **kwargs: Any,
    ) -> "RegistryModelVersion":
        """Registers a model version in the model registry.

        Args:
            name: The name of the registered model this version belongs to.
            description: Optional description of the model version.
            tags: Optional string-to-string tags for the model version.
            model_source_uri: Optional URI of the model's source artifacts.
            version: Optional explicit version identifier.
            metadata: Optional string-to-string metadata for the version.
            zenml_version: Optional ZenML version to record with the version.
            zenml_run_name: Optional ZenML run name to record.
            zenml_pipeline_name: Optional ZenML pipeline name to record.
            zenml_step_name: Optional ZenML step name to record.
            **kwargs: Additional, implementation-specific keyword arguments.

        Returns:
            The registered model version.
        """

    @abstractmethod
    def delete_model_version(
        self,
        name: str,
        version: str,
    ) -> None:
        """Deletes a model version from the model registry.

        Args:
            name: The name of the registered model.
            version: The version to delete.
        """

    @abstractmethod
    def update_model_version(
        self,
        name: str,
        version: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        stage: Optional["ModelVersionStage"] = None,
    ) -> "RegistryModelVersion":
        """Updates a model version in the model registry.

        Args:
            name: The name of the registered model.
            version: The version to update.
            description: Optional new description for the version.
            tags: Optional string-to-string tags for the version.
            stage: Optional stage to assign to the version.

        Returns:
            The updated model version.
        """

    @abstractmethod
    def list_model_versions(
        self,
        name: Optional[str] = None,
        model_source_uri: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List["RegistryModelVersion"]:
        """Lists all model versions for a registered model.

        Args:
            name: Optional registered-model name used to narrow the listing.
            model_source_uri: Optional source URI used to narrow the listing.
            tags: Optional string-to-string tags used to narrow the listing.
            **kwargs: Additional, implementation-specific keyword arguments.

        Returns:
            The matching model versions.
        """

    @abstractmethod
    def get_model_version(
        self, name: str, version: str
    ) -> "RegistryModelVersion":
        """Gets a model version for a registered model.

        Args:
            name: The name of the registered model.
            version: The version to fetch.

        Returns:
            The model version.
        """

    @abstractmethod
    def load_model_version(
        self,
        name: str,
        version: str,
        **kwargs: Any,
    ) -> Any:
        """Loads a model version from the model registry.

        Args:
            name: The name of the registered model.
            version: The version to load.
            **kwargs: Additional, implementation-specific keyword arguments.

        Returns:
            The loaded model; its type is implementation-defined.
        """

    @abstractmethod
    def get_model_uri_artifact_store(
        self,
        model_version: "RegistryModelVersion",
    ) -> str:
        """Gets the URI artifact store for a model version.

        Args:
            model_version: The model version whose URI should be resolved.

        Returns:
            The URI as a string.
        """


# Build your own custom model registry
Last updated
Was this helpful?