Module core.steps.trainer.base_trainer

Classes

BaseTrainerStep(serving_model_dir: str = None, transform_output: str = None, train_files=None, eval_files=None, **kwargs) : Base step interface for all Trainer steps. All code concerning model training should live in subclasses of this class.

Constructor for the BaseTrainerStep. All subclasses used for custom
training of user machine learning models should implement the `run_fn`,
`model_fn` and `input_fn` methods, which handle control flow, model
training and input data preparation, respectively.

Args:
    serving_model_dir: Directory in which to save the trained
     model.
    transform_output: Output of the preceding Transform component.
    train_files: String, file pattern of the location of TFRecords for
     model training. Intended for use in the input_fn.
    eval_files: String, file pattern of the location of TFRecords for
     model evaluation. Intended for use in the input_fn.
    log_dir: Logs output directory.
    schema: Schema file produced by a preceding SchemaGen component.

### Ancestors (in MRO)

* zenml.core.steps.base_step.BaseStep

### Class variables

`STEP_TYPE`
:

### Static methods

`model_fn(train_dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, eval_dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2)`
:   Static method defining the training flow of the model. Override this
    method in subclasses to define your own custom training flow.
    
    Args:
        train_dataset: tf.data.Dataset containing the training data.
        eval_dataset: tf.data.Dataset containing the evaluation data.
    
    Returns:
        model: A trained machine learning model.
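A minimal sketch of a `model_fn` override, under stated assumptions: each dataset yields `(feature, label)` pairs of floats, and plain Python lists stand in for `tf.data.Dataset` so the sketch runs without TensorFlow. `MyTrainerStep` and `LinearModel` are hypothetical names; in a real pipeline the class would inherit from `BaseTrainerStep`, and the body would typically build and fit a `tf.keras` model rather than the closed-form least-squares fit used here.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple

# Hypothetical dataset element: (feature, label). Stands in for a tf.data.Dataset element.
Example = Tuple[float, float]


@dataclass
class LinearModel:
    slope: float
    intercept: float
    eval_mse: float = float("nan")  # filled in after evaluation

    def predict(self, x: float) -> float:
        return self.slope * x + self.intercept


class MyTrainerStep:  # hypothetical; would subclass BaseTrainerStep in a real pipeline
    @staticmethod
    def model_fn(train_dataset: Iterable[Example],
                 eval_dataset: Iterable[Example]) -> LinearModel:
        """Fit y = slope * x + intercept by ordinary least squares, then evaluate."""
        xs, ys = zip(*train_dataset)
        mean_x = sum(xs) / len(xs)
        mean_y = sum(ys) / len(ys)
        var_x = sum((x - mean_x) ** 2 for x in xs)
        if var_x == 0:
            raise ValueError("training features must not all be equal")
        cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
        slope = cov_xy / var_x
        model = LinearModel(slope=slope, intercept=mean_y - slope * mean_x)

        # Score the fitted model on the held-out split it was given.
        eval_pairs = list(eval_dataset)
        model.eval_mse = sum((model.predict(x) - y) ** 2 for x, y in eval_pairs) / len(eval_pairs)
        return model
```

Because `model_fn` is a static method, it receives both splits as arguments and has no access to instance state; everything it needs must arrive through the two datasets.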

### Methods

`get_run_fn(self)`
:

`input_fn(self, file_pattern: List[str], tf_transform_output: tensorflow_transform.output_wrapper.TFTransformOutput)`
:   Method for loading data from TFRecords saved to a location on
    disk. Override this method in subclasses to define your own custom
    data preparation flow.
    
    Args:
        file_pattern: File pattern(s) matching saved TFRecords on disk.
        tf_transform_output: Output of the preceding Transform /
         Preprocessing component.
    
    Returns:
        dataset: A tf.data.Dataset constructed from the input file
         pattern and transform.
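The sketch below illustrates what loading data from TFRecords involves at the file level, assuming GZIP-compressed TFRecord files (common for TFX Transform outputs, though not confirmed for this step). `read_tfrecords` is a hypothetical helper, not part of this module: it expands each file pattern with `glob` and walks the TFRecord framing (an 8-byte little-endian payload length, a 4-byte CRC of that length, the payload, then a 4-byte CRC of the payload), skipping CRC verification for brevity. A real `input_fn` would instead pass the matched files to `tf.data.TFRecordDataset` and parse each record with the feature spec from `tf_transform_output.transformed_feature_spec()`.

```python
import glob
import gzip
import struct
from typing import Iterator, List


def read_tfrecords(file_pattern: List[str], compressed: bool = True) -> Iterator[bytes]:
    """Yield raw serialized records from every file matching the patterns.

    Hypothetical helper. CRC checksums are NOT verified here; a production
    reader such as tf.data.TFRecordDataset does verify them.
    """
    # Expand each glob pattern; dedupe across patterns and keep a stable order.
    files = sorted({path for pattern in file_pattern for path in glob.glob(pattern)})
    opener = gzip.open if compressed else open
    for path in files:
        with opener(path, "rb") as f:
            while True:
                header = f.read(12)  # uint64 length + uint32 masked CRC of the length
                if not header:
                    break  # clean end of file
                if len(header) < 12:
                    raise ValueError(f"truncated record header in {path}")
                (length,) = struct.unpack("<Q", header[:8])
                data = f.read(length)
                footer = f.read(4)  # uint32 masked CRC of the payload (skipped)
                if len(data) < length or len(footer) < 4:
                    raise ValueError(f"truncated record body in {path}")
                yield data
```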

`run_fn(self)`
:   Method defining the control flow of the training process inside
    the TFX Trainer Component Executor. Override this method in subclasses
    to define your own custom training flow.
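To make the division of work concrete, here is a self-contained sketch of how a `run_fn` might orchestrate the other two methods: it prepares the training and evaluation splits with `input_fn`, trains with `model_fn`, and exports the result to `serving_model_dir`. `TrainerStepSketch` is a hypothetical stand-in, not the zenml base class; it reads plain-text `x,y` files instead of TFRecords, uses a trivial mean-label model, and writes JSON rather than a TensorFlow SavedModel, so it runs without TensorFlow. Since the constructor documents `train_files` and `eval_files` as single file-pattern strings while `input_fn` accepts `List[str]`, the sketch wraps each pattern in a list.

```python
import glob
import json
import os
from typing import List, Optional, Tuple

Example = Tuple[float, float]  # hypothetical (feature, label) element


class TrainerStepSketch:
    """Hypothetical stand-in for a BaseTrainerStep subclass (not the zenml API)."""

    def __init__(self, serving_model_dir: str, transform_output: Optional[str] = None,
                 train_files: Optional[str] = None, eval_files: Optional[str] = None):
        self.serving_model_dir = serving_model_dir
        self.transform_output = transform_output
        self.train_files = train_files
        self.eval_files = eval_files

    def input_fn(self, file_pattern: List[str], tf_transform_output) -> List[Example]:
        # Data preparation: read "x,y" lines from every matching file.
        # tf_transform_output is unused here; a real step would use it to parse TFRecords.
        examples = []
        for pattern in file_pattern:
            for path in sorted(glob.glob(pattern)):
                with open(path) as f:
                    for line in f:
                        if line.strip():
                            x, y = line.split(",")
                            examples.append((float(x), float(y)))
        return examples

    @staticmethod
    def model_fn(train_dataset: List[Example], eval_dataset: List[Example]) -> dict:
        # Model training: a deliberately trivial model that predicts the mean label.
        mean = sum(y for _, y in train_dataset) / len(train_dataset)
        mse = sum((y - mean) ** 2 for _, y in eval_dataset) / len(eval_dataset)
        return {"mean_label": mean, "eval_mse": mse}

    def run_fn(self) -> str:
        # Control flow: prepare both splits, train, then export for serving.
        train = self.input_fn([self.train_files], self.transform_output)
        evaluation = self.input_fn([self.eval_files], self.transform_output)
        model = self.model_fn(train, evaluation)
        os.makedirs(self.serving_model_dir, exist_ok=True)
        path = os.path.join(self.serving_model_dir, "model.json")
        with open(path, "w") as f:
            json.dump(model, f)
        return path
```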