Module core.steps.trainer.pytorch_trainers.torch_base_trainer
Classes¶
TorchBaseTrainerStep(serving_model_dir: str = None, transform_output: str = None, train_files=None, eval_files=None, **kwargs)
: Base class for all PyTorch-based trainer steps. All PyTorch-based
training steps should subclass it. An example is available in
torch_ff_trainer.FeedForwardTrainer.
Constructor for the BaseTrainerStep. All subclasses used for custom
training of user machine learning models should implement the `run_fn`,
`model_fn` and `input_fn` methods, which handle control flow, model
training and input data preparation, respectively.
Args:
* serving_model_dir: Directory where the trained model is saved.
* transform_output: Output of a preceding transform component.
* train_files: String; file pattern matching the TFRecords used for model training. Intended for use in `input_fn`.
* eval_files: String; file pattern matching the TFRecords used for model evaluation. Intended for use in `input_fn`.
* log_dir: Output directory for logs.
* schema: Schema file produced by a preceding SchemaGen component.
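A minimal sketch of the subclassing contract described above, under stated assumptions: `StandInTorchTrainer` is a hypothetical stand-in, not ZenML's `TorchBaseTrainerStep`, and the `run_fn`, `model_fn` and `input_fn` signatures shown here are illustrative rather than the library's actual ones. The toy subclass expands `train_files` and `eval_files` as glob patterns and reads plain text files in place of TFRecords, so the example runs with the standard library alone.

```python
import glob
import os
import tempfile
from typing import Any, Dict, List, Optional


class StandInTorchTrainer:
    """Hypothetical stand-in for TorchBaseTrainerStep; not ZenML's real class."""

    def __init__(self, serving_model_dir: Optional[str] = None,
                 transform_output: Optional[str] = None,
                 train_files: Optional[str] = None,
                 eval_files: Optional[str] = None, **kwargs: Any) -> None:
        self.serving_model_dir = serving_model_dir
        self.transform_output = transform_output
        self.train_files = train_files
        self.eval_files = eval_files
        # Assumption: extra settings such as log_dir or schema arrive via **kwargs.
        self.extra = kwargs

    def input_fn(self, file_pattern: str) -> List[float]:
        raise NotImplementedError

    def model_fn(self, train_data: List[float], eval_data: List[float]) -> Dict[str, float]:
        raise NotImplementedError

    def run_fn(self) -> Dict[str, float]:
        raise NotImplementedError


class MeanPredictorTrainer(StandInTorchTrainer):
    """Toy subclass whose 'model' always predicts the mean training label."""

    def input_fn(self, file_pattern: str) -> List[float]:
        # Input data preparation: expand the pattern, then read one number per line.
        # Plain text files stand in for TFRecords to keep the sketch dependency-free.
        paths = sorted(glob.glob(file_pattern))
        if not paths:
            raise FileNotFoundError(f"No files match pattern: {file_pattern}")
        values: List[float] = []
        for path in paths:
            with open(path) as f:
                values.extend(float(line) for line in f if line.strip())
        return values

    def model_fn(self, train_data: List[float], eval_data: List[float]) -> Dict[str, float]:
        # Model training: fit a constant predictor, then report its eval MSE.
        mean = sum(train_data) / len(train_data)
        mse = sum((y - mean) ** 2 for y in eval_data) / len(eval_data)
        return {"mean": mean, "eval_mse": mse}

    def run_fn(self) -> Dict[str, float]:
        # Control flow: prepare both splits, then train and evaluate.
        return self.model_fn(self.input_fn(self.train_files),
                             self.input_fn(self.eval_files))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Write three small shard files that the patterns below will match.
        shards = {"train-0.txt": "1\n2\n", "train-1.txt": "3\n", "eval-0.txt": "2\n4\n"}
        for name, text in shards.items():
            with open(os.path.join(tmp, name), "w") as f:
                f.write(text)
        trainer = MeanPredictorTrainer(
            train_files=os.path.join(tmp, "train-*.txt"),
            eval_files=os.path.join(tmp, "eval-*.txt"),
        )
        print(trainer.run_fn())  # {'mean': 2.0, 'eval_mse': 2.0}
```

Keeping the control flow in `run_fn` lets a subclass change its data loading (`input_fn`) or its model (`model_fn`) independently of the other.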
### Ancestors (in MRO)
* zenml.core.steps.trainer.base_trainer.BaseTrainerStep
* zenml.core.steps.base_step.BaseStep
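The "Ancestors (in MRO)" list gives the method resolution order: the classes Python searches, in order, when an attribute is not defined on `TorchBaseTrainerStep` itself. A small sketch using empty, hypothetical stand-in classes that mirror only the names of the hierarchy above; the `describe` method is invented for illustration. Python always appends `object` at the end, which the list above does not show.

```python
class BaseStep:
    """Hypothetical stand-in for zenml.core.steps.base_step.BaseStep."""

    def describe(self) -> str:
        return "BaseStep.describe"


class BaseTrainerStep(BaseStep):
    """Hypothetical stand-in for zenml.core.steps.trainer.base_trainer.BaseTrainerStep."""


class TorchBaseTrainerStep(BaseTrainerStep):
    """Hypothetical stand-in for the class documented on this page."""


# The classes Python searches, in this order, when resolving an attribute.
mro_names = [cls.__name__ for cls in TorchBaseTrainerStep.__mro__]
print(mro_names)
# ['TorchBaseTrainerStep', 'BaseTrainerStep', 'BaseStep', 'object']

# describe() is not defined on either subclass, so lookup falls through to BaseStep.
print(TorchBaseTrainerStep().describe())  # BaseStep.describe
```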