Creating your trainer


In ZenML, model training is implemented through an interface called the BaseTrainerStep. This interface has one main abstract method (run_fn) and three helper methods (input_fn, model_fn, test_fn), each handling a different part of the model training workflow.

```python
from abc import abstractmethod


# Overview of the BaseTrainerStep
class BaseTrainerStep(BaseStep):

    def input_fn(self, *args, **kwargs):
        ...

    def model_fn(self, *args, **kwargs):
        ...

    def test_fn(self, *args, **kwargs):
        ...

    @abstractmethod
    def run_fn(self):
        ...
```
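The split between one abstract run_fn and three optional helpers can be reproduced with Python's standard abc module. The sketch below is self-contained and does not use ZenML: BaseStep here is a hypothetical stand-in, and having the helpers raise NotImplementedError is an assumption made for illustration, not ZenML's actual default behaviour.

```python
from abc import ABC, abstractmethod


class BaseStep(ABC):
    """Hypothetical stand-in for ZenML's BaseStep (not the real class)."""


class BaseTrainerStep(BaseStep):
    """Sketch of the trainer interface: three optional helpers, one abstract method."""

    def input_fn(self, *args, **kwargs):
        # Optional helper: turn data locations into a dataset
        raise NotImplementedError

    def model_fn(self, *args, **kwargs):
        # Optional helper: build and return a model instance
        raise NotImplementedError

    def test_fn(self, *args, **kwargs):
        # Optional helper: compute and store test results
        raise NotImplementedError

    @abstractmethod
    def run_fn(self):
        # Every concrete trainer must implement this method
        ...


class MinimalTrainer(BaseTrainerStep):
    # Overriding run_fn alone is enough to make the class instantiable
    def run_fn(self):
        return "trained"
```

Instantiating BaseTrainerStep directly raises a TypeError because run_fn is still abstract, while MinimalTrainer() works even though it never overrides the helpers.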


input_fn is the first of the helper methods mentioned above. As the name suggests, its purpose is to separate data preparation from training. In the built-in TrainerSteps, for instance, it reads the data using the parameter file_patterns, a list of URIs pointing to the data, and prepares a dataset that the implemented model architecture can consume.

```python
def input_fn(self, *args, **kwargs):
    ...
```
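As a rough, framework-free illustration of that separation, here is a hypothetical input_fn that expands a list of file patterns and groups the rows into batches. It swaps the built-in TFRecord reading for CSV files read with the standard csv and glob modules; the batch_size parameter and the list-of-batches return type are assumptions made for this sketch.

```python
import csv
import glob
from typing import Dict, List, Text

Batch = List[Dict[str, str]]


def input_fn(file_patterns: List[Text], batch_size: int = 2) -> List[Batch]:
    """Hypothetical input_fn: read all CSV files matching the patterns into row batches."""
    # Expand every glob pattern into concrete paths; sort for a deterministic order
    paths = sorted(path for pattern in file_patterns for path in glob.glob(pattern))
    batches: List[Batch] = []
    current: Batch = []
    for path in paths:
        with open(path, newline="") as f:
            for row in csv.DictReader(f):
                current.append(row)
                if len(current) == batch_size:
                    batches.append(current)
                    current = []
    # Keep the final, possibly smaller batch
    if current:
        batches.append(current)
    return batches
```

Returning a list rather than a generator keeps the dataset re-iterable, which matters when the training loop makes several passes over it.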


The next helper method is model_fn. Like input_fn, it separates model preparation from training. The built-in TrainerSteps use this method to create an instance of the model architecture and return it.

```python
def model_fn(self, *args, **kwargs):
    ...
```
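In the TorchFeedForwardTrainer example further down, model_fn receives the training and evaluation datasets, so one possible use is to size the model from the data. The toy sketch below assumes exactly that: LogisticModel and this model_fn are hypothetical, and a dataset is assumed to be a list of batches of feature vectors.

```python
from typing import List, Optional, Sequence


class LogisticModel:
    """Hypothetical toy model: logistic regression with one weight per feature."""

    def __init__(self, n_features: int):
        self.weights = [0.0] * n_features
        self.bias = 0.0

    def __call__(self, features: Sequence[float]) -> float:
        # Return the raw logit; a sigmoid would turn it into a probability
        return sum(w * x for w, x in zip(self.weights, features)) + self.bias


def model_fn(train_dataset: List[List[Sequence[float]]],
             eval_dataset: Optional[list] = None) -> LogisticModel:
    """Hypothetical model_fn: build a fresh, untrained model sized from the data."""
    first_example = train_dataset[0][0]
    return LogisticModel(n_features=len(first_example))
```

Each call returns a new instance, so re-running the training never starts from weights left over by a previous run.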


The last helper method of a BaseTrainerStep is test_fn. Its goal is to separate the computation of the test results, which happens once training is complete. In the built-in TrainerSteps, this method uses the trained model to compute the model output on the test splits and stores the results as an artifact.

Storing the test results within the TrainerStep is especially important for post-training evaluation, because it allows the next step to use a model-agnostic evaluator.

```python
def test_fn(self, *args, **kwargs):
    ...
```


run_fn is the only abstract method, and it is where everything related to training comes together: within it, the required datasets are created, the model is trained, and evaluation and testing follow.

```python
def run_fn(self):
    ...
```
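To show how run_fn can tie the helpers together, here is a self-contained toy trainer that depends on neither ZenML nor PyTorch. It swaps in a single-layer logistic regression trained with plain mini-batch gradient descent; ToyTrainer, its in-memory data dictionary (standing in for the split file patterns) and its epochs and lr parameters are all hypothetical.

```python
import math
from typing import Dict, List, Tuple

# A batch is (list of feature vectors, list of 0/1 labels)
Batch = Tuple[List[List[float]], List[float]]


class ToyTrainer:
    """Hypothetical trainer mirroring the input_fn/model_fn/test_fn/run_fn split."""

    def __init__(self, data: Dict[str, List[Batch]], epochs: int = 200, lr: float = 0.5):
        self.data = data  # stands in for the split file patterns
        self.epochs = epochs
        self.lr = lr

    def input_fn(self, split: str) -> List[Batch]:
        # Data preparation: here just a lookup of the in-memory split
        return self.data[split]

    def model_fn(self, train_dataset: List[Batch]) -> dict:
        # Model preparation: one zero weight per feature plus a zero bias
        n_features = len(train_dataset[0][0][0])
        return {"w": [0.0] * n_features, "b": 0.0}

    @staticmethod
    def predict(model: dict, x: List[float]) -> float:
        # Numerically stable sigmoid of the linear score
        z = sum(w * xi for w, xi in zip(model["w"], x)) + model["b"]
        if z >= 0:
            return 1.0 / (1.0 + math.exp(-z))
        e = math.exp(z)
        return e / (1.0 + e)

    def test_fn(self, model: dict, dataset: List[Batch]) -> List[Tuple[float, float]]:
        # Store (label, predicted probability) pairs for a later evaluator
        return [(y, self.predict(model, x))
                for xs, ys in dataset for x, y in zip(xs, ys)]

    def run_fn(self) -> List[Tuple[float, float]]:
        # 1. Prepare the datasets
        train_dataset = self.input_fn("train")
        test_dataset = self.input_fn("test")
        # 2. Prepare the model
        model = self.model_fn(train_dataset)
        # 3. Train with mini-batch gradient descent on binary cross-entropy
        for _ in range(self.epochs):
            for xs, ys in train_dataset:
                grad_w = [0.0] * len(model["w"])
                grad_b = 0.0
                for x, y in zip(xs, ys):
                    err = self.predict(model, x) - y  # dLoss/dz for cross-entropy
                    for i, xi in enumerate(x):
                        grad_w[i] += err * xi / len(xs)
                    grad_b += err / len(xs)
                model["w"] = [w - self.lr * g for w, g in zip(model["w"], grad_w)]
                model["b"] -= self.lr * grad_b
        # 4. Test and return the stored results
        return self.test_fn(model, test_dataset)
```

run_fn only orchestrates: data comes from input_fn, the model from model_fn and the stored results from test_fn, so each piece can be swapped independently.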

A quick example: the built-in TorchFeedForwardTrainer step

The following is an overview of the complete step. You can find the full code right here.

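```python
from typing import List, Text

import torch
import torch.nn as nn
import torch.optim as optim

# BaseTrainerStep, torch_utils and utils are ZenML modules; see the full code


class BinaryClassifier(nn.Module):
    ...  # layer definitions omitted in this overview


def binary_acc(y_pred, y_test):
    ...  # accuracy computation omitted in this overview


class TorchFeedForwardTrainer(BaseTrainerStep):

    def input_fn(self,
                 file_patterns: List[Text]):
        """Method which creates the datasets for model training"""
        # further dataset arguments are omitted in this overview
        dataset = torch_utils.TFRecordTorchDataset(file_patterns)
        # batch_size is assumed to be a constructor input of the step
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=self.batch_size)
        return loader

    def model_fn(self, train_dataset, eval_dataset):
        """Method which prepares the model instance"""
        return BinaryClassifier()

    @torch.no_grad()  # no gradients are needed to compute test outputs
    def test_fn(self, model, dataset):
        """Method which computes and stores the test results"""
        model.eval()  # switch off training-only layers such as dropout
        device = next(model.parameters()).device
        batch_list = []
        for x, y, raw in dataset:
            # start with an empty batch
            batch = {}
            # add the raw features with the transformed features and labels
            batch.update(x)
            batch.update(y)
            batch.update(raw)
            # finally, add the output of the model
            x_batch = torch.cat([v.to(device) for v in x.values()], dim=-1)
            p = model(x_batch)
            if isinstance(p, torch.Tensor):
                batch.update({'output': p})
            elif isinstance(p, dict):
                batch.update(p)
            elif isinstance(p, list):
                batch.update(
                    {'output_{}'.format(i): v for i, v in enumerate(p)})
            else:
                raise TypeError('Unknown output format!')
            batch_list.append(batch)

        combined_batch = utils.combine_batch_results(batch_list)
        return combined_batch

    def run_fn(self):
        """Function which handles the model training"""
        # train_split_patterns / eval_split_patterns are resolved in the full code
        # Prepare the datasets
        train_dataset = self.input_fn(train_split_patterns)
        eval_dataset = self.input_fn(eval_split_patterns)

        # Prepare the model
        model = self.model_fn(train_dataset, eval_dataset)

        # Execute the training
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.train()
        criterion = nn.BCEWithLogitsLoss()
        # lr is assumed to be a constructor input of the step
        optimizer = optim.Adam(model.parameters(), lr=self.lr)

        for e in range(1, self.epochs + 1):
            for x, y, _ in train_dataset:
                x_batch = torch.cat([v.to(device) for v in x.values()], dim=-1)
                y_batch = torch.cat([v.to(device) for v in y.values()], dim=-1)
                optimizer.zero_grad()
                y_pred = model(x_batch)
                loss = criterion(y_pred, y_batch)
                acc = binary_acc(y_pred, y_batch)
                loss.backward()
                optimizer.step()
```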
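Since the trainer uses nn.BCEWithLogitsLoss, the model outputs raw logits rather than probabilities, so binary_acc has to threshold them accordingly. The overview omits its body; the pure-Python version below is a hypothetical sketch returning a fraction in [0, 1]. It relies on the fact that sigmoid(z) >= 0.5 exactly when z >= 0, so no sigmoid (and no risk of overflow in an exponential) is needed.

```python
from typing import Sequence


def binary_acc(logits: Sequence[float], labels: Sequence[float]) -> float:
    """Hypothetical binary accuracy computed directly from raw logits."""
    # sigmoid(z) >= 0.5  <=>  z >= 0, so threshold the logit at zero
    preds = [1.0 if z >= 0 else 0.0 for z in logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)
```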

We can now go ahead and use this step in our pipeline:

```python
from zenml.pipelines import TrainingPipeline
from zenml.steps.split import RandomSplit

training_pipeline = TrainingPipeline()

# ... (the remaining pipeline configuration is not shown here)
```

An important note here: any input given to the constructor of a step becomes an instance variable. Inside the step you can therefore access it through self, as run_fn does above with self.epochs.
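A minimal, hypothetical sketch of that behaviour in plain Python follows. It is not ZenML's actual BaseStep, just one way to turn constructor keywords into instance variables with setattr; StepSketch and TrainerSketch are illustrative names.

```python
from typing import Any


class StepSketch:
    """Hypothetical base class: every constructor keyword becomes an instance attribute."""

    def __init__(self, **params: Any):
        self.params = dict(params)  # keep a record of the full configuration
        for name, value in params.items():
            setattr(self, name, value)


class TrainerSketch(StepSketch):
    def run_fn(self):
        # Constructor inputs are available through self
        return f"training for {self.epochs} epochs with lr={self.lr}"
```

For example, TrainerSketch(epochs=3, lr=0.01).run_fn() reads both values back through self.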

What's next?

  • You can learn more about how the instance variables work in any step here. [WIP]

  • The next step within the TrainingPipeline is the Evaluator step.