Module core.steps.trainer.tensorflow_trainers.tf_ff_trainer

Classes

FeedForwardTrainer(batch_size: int = 8, lr: float = 0.001, epochs: int = 1, dropout_chance: float = 0.2, loss: str = 'mse', metrics: List[str] = None, hidden_layers: List[int] = None, hidden_activation: str = 'relu', last_activation: str = 'sigmoid', output_units: int = 1, **kwargs) : Basic feedforward neural network trainer. This step serves as an example of how to define your training logic so that it integrates well with TFX and TensorFlow Serving.

Basic feedforward neural network constructor.

Args:
    batch_size: Input data batch size.
    lr: Learning rate of the optimizer.
    epochs: Number of training epochs (whole passes through the data).
    dropout_chance: Dropout rate, i.e. the probability that a unit's output
     is set to zero at a dropout layer during training.
    loss: Name of the loss function to use.
    metrics: List of metrics to log in training.
    hidden_layers: List of sizes to use in consecutive hidden layers.
     Length determines the number of hidden layers in the network.
    hidden_activation: Name of the activation function to use in the
     hidden layers.
    last_activation: Name of the activation function applied to the output
     layer to produce the class probabilities.
    output_units: Number of units in the output layer, typically the number
     of classes (one unit suffices for binary classification with the
     default sigmoid output).
    **kwargs: Additional keyword arguments.
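The arguments above describe the shape of the network. As a rough, plain-Python sketch (not the step's actual code), the hypothetical helper `describe_ff_architecture` below turns them into an ordered list of layer specs. It assumes one dropout layer after each hidden dense layer and treats `hidden_layers=None` as "no hidden layers"; the step's real default and dropout placement are not documented here.

```python
from typing import List, Optional, Tuple, Union

# Hypothetical layer description: (kind, units or dropout rate, activation).
LayerSpec = Tuple[str, Union[int, float], Optional[str]]


def describe_ff_architecture(
    hidden_layers: Optional[List[int]] = None,
    hidden_activation: str = "relu",
    last_activation: str = "sigmoid",
    output_units: int = 1,
    dropout_chance: float = 0.2,
) -> List[LayerSpec]:
    """Map constructor-style arguments to an ordered list of layer specs.

    Assumptions (not confirmed by the step's docs): a dropout layer follows
    every hidden dense layer, and None means "no hidden layers".
    """
    layers: List[LayerSpec] = []
    for units in hidden_layers or []:
        # One hidden dense layer per list entry, in order.
        layers.append(("dense", units, hidden_activation))
        layers.append(("dropout", dropout_chance, None))
    # The output layer uses `output_units` units and `last_activation`.
    layers.append(("dense", output_units, last_activation))
    return layers
```

For example, `hidden_layers=[64, 32]` yields two ReLU dense layers (64 and 32 units), each followed by a dropout layer with rate 0.2, and a one-unit sigmoid output layer.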

### Ancestors (in MRO)

* zenml.core.steps.trainer.tensorflow_trainers.tf_base_trainer.TFBaseTrainerStep
* zenml.core.steps.trainer.base_trainer.BaseTrainerStep
* zenml.core.steps.base_step.BaseStep

### Methods

`get_run_fn(self)`
:

`input_fn(self, file_pattern: List[str], tf_transform_output: tensorflow_transform.output_wrapper.TFTransformOutput)`
:   Feedforward input_fn for loading data from TFRecords saved to a
    location on disk.
    
    Args:
        file_pattern: File pattern matching saved TFRecords on disk.
        tf_transform_output: Output of the preceding Transform /
         Preprocessing component.
    
    Returns:
        dataset: tf.data.Dataset created out of the input files.
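The real `input_fn` returns a `tf.data.Dataset`, and its internals are not shown on this page. The plain-Python sketch below only mirrors two parts of what such a loader typically does: expanding each entry of `file_pattern` into concrete paths and grouping records into batches of `batch_size`. The helpers `expand_file_patterns` and `batched_records` are hypothetical, text lines stand in for serialized TFRecord examples, and feature parsing (for which `TFTransformOutput` offers `transformed_feature_spec()`) is omitted.

```python
import glob
from typing import Iterator, List


def expand_file_patterns(file_pattern: List[str]) -> List[str]:
    """Expand every glob pattern in the list into concrete, sorted paths."""
    paths: List[str] = []
    for pattern in file_pattern:
        paths.extend(sorted(glob.glob(pattern)))
    return paths


def batched_records(file_pattern: List[str], batch_size: int = 8) -> Iterator[List[str]]:
    """Yield lists of up to `batch_size` records, reading files in order.

    Stand-in assumption: each file holds one text record per line, instead
    of serialized tf.Example protos inside TFRecord files.
    """
    batch: List[str] = []
    for path in expand_file_patterns(file_pattern):
        with open(path, encoding="utf-8") as handle:
            for line in handle:
                batch.append(line.rstrip("\n"))
                if len(batch) == batch_size:
                    yield batch
                    batch = []
    if batch:  # final, possibly smaller batch
        yield batch
```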

`model_fn(self, train_dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, eval_dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2)`
:   Function defining the training flow of the feedforward neural network
    model.
    
    Args:
        train_dataset: tf.data.Dataset containing the training data.
        eval_dataset: tf.data.Dataset containing the evaluation data.
    
    Returns:
        model: A trained feedforward neural network model.

`run_fn(self)`
:   Method defining the control flow of the training process inside the TFX
    Trainer component executor. Override this method in subclasses to define
    your own custom training flow.
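The base class defines a default flow built from `input_fn` and `model_fn`; a subclass can replace that flow by overriding `run_fn`. The sketch below illustrates the idea with hypothetical stand-in classes (`SketchTrainerStep`, `TrainOnAllDataTrainer`) and in-memory lists instead of datasets. It does not reproduce zenml's base classes or how their `run_fn` obtains its data inside the TFX executor.

```python
from typing import Dict, List


class SketchTrainerStep:
    """Hypothetical stand-in for a trainer step (NOT the zenml base class)."""

    def __init__(self, splits: Dict[str, List[int]]):
        # Hypothetical in-memory "datasets" keyed by split name.
        self.splits = splits

    def input_fn(self, split: str) -> List[int]:
        return list(self.splits[split])

    def model_fn(self, train_dataset: List[int], eval_dataset: List[int]) -> Dict[str, int]:
        # Placeholder "training": just record how much data each phase saw.
        return {"n_train": len(train_dataset), "n_eval": len(eval_dataset)}

    def run_fn(self) -> Dict[str, int]:
        # Default control flow: load the train and eval splits, then train.
        return self.model_fn(self.input_fn("train"), self.input_fn("eval"))


class TrainOnAllDataTrainer(SketchTrainerStep):
    """Custom flow: fit on train and eval combined, with no held-out eval."""

    def run_fn(self) -> Dict[str, int]:
        combined = self.input_fn("train") + self.input_fn("eval")
        return self.model_fn(combined, [])
```

Only `run_fn` changes in the subclass; `input_fn` and `model_fn` are reused as-is, which keeps the custom flow small.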