Module core.steps.trainer.pytorch_trainers.torch_ff_trainer

Functions

binary_acc(y_pred, y_test) :
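`binary_acc` has no docstring. Judging by its name and signature, it most likely computes the accuracy of binary predictions against labels. Below is one plausible implementation, not the module's actual code. It assumes `y_pred` already holds probabilities (the trainer's default `last_activation` is `'sigmoid'`) and uses a 0.5 threshold. The `threshold` parameter is a hypothetical addition.

```python
import torch


def binary_acc(y_pred: torch.Tensor, y_test: torch.Tensor, threshold: float = 0.5) -> float:
    """Hypothetical sketch: fraction of predictions that match the binary labels."""
    # Assumption: y_pred holds probabilities (e.g. sigmoid outputs), not logits.
    # Flatten both tensors so shapes (N, 1) and (N,) compare element-wise
    # instead of broadcasting to an (N, N) matrix.
    preds = (y_pred.reshape(-1) >= threshold).float()
    targets = y_test.reshape(-1).float()
    correct = (preds == targets).sum().item()
    return correct / targets.numel()
```

If the network returns raw logits instead, apply `torch.sigmoid` before thresholding. Some implementations also scale the result to a percentage rather than returning a fraction.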

Classes

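BinaryClassifier() : Neural network module (a `torch.nn.Module` subclass) defined in this module. It has no docstring of its own. The text below is inherited from `torch.nn.Module`: base class for all neural network modules.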

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.

:ivar training: Boolean that represents whether this module is in training or
                evaluation mode.
:vartype training: bool

Initializes internal Module state, shared by both nn.Module and ScriptModule.
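The inherited docstring does not describe the network itself. Below is a hypothetical sketch of what a feed-forward `BinaryClassifier` could look like. It assumes the constructor arguments mirror FeedForwardTrainer's `input_units`, `hidden_layers`, `output_units` and `dropout_chance` parameters, and that the layers use ReLU and sigmoid to match the `'relu'` and `'sigmoid'` defaults. The default hidden sizes `[64, 32]` are invented for illustration. Holding the layers in an `nn.ModuleList` registers them as submodules, as the docstring describes.

```python
from typing import List, Optional

import torch
import torch.nn as nn


class BinaryClassifier(nn.Module):
    """Hypothetical feed-forward binary classifier (not the module's actual code)."""

    def __init__(self, input_units: int = 8, hidden_layers: Optional[List[int]] = None,
                 output_units: int = 1, dropout_chance: float = 0.2):
        super().__init__()
        # Assumed default hidden sizes; the real defaults are not documented here.
        hidden_layers = hidden_layers if hidden_layers is not None else [64, 32]
        sizes = [input_units] + list(hidden_layers)
        # nn.ModuleList registers each layer, so .parameters() and .to() see them.
        self.hidden = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.dropout = nn.Dropout(p=dropout_chance)
        self.out = nn.Linear(sizes[-1], output_units)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = inputs
        for layer in self.hidden:
            # ReLU hidden activation, matching hidden_activation='relu'.
            x = self.dropout(torch.relu(layer(x)))
        # Sigmoid squashes the output into [0, 1], matching last_activation='sigmoid'.
        return torch.sigmoid(self.out(x))
```

Calling `model.eval()` disables dropout for inference. `model.train()` re-enables it.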

### Ancestors (in MRO)

* torch.nn.modules.module.Module

### Class variables

`dump_patches: bool`
:

`training: bool`
:   Whether this module is in training mode (`True`) or evaluation mode (`False`).

### Methods

`forward(self, inputs) ‑> Callable[..., Any]`
:   Defines the computation performed at every call.
    
    Should be overridden by all subclasses.
    
    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
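The note above can be checked directly with a forward hook. The short example below uses only standard `torch.nn` calls: the hook fires when the module instance is called, but not when `forward()` is called directly.

```python
import torch
import torch.nn as nn

fired = []
layer = nn.Linear(2, 1)
# The hook appends a marker every time it runs; its signature is (module, input, output).
handle = layer.register_forward_hook(lambda module, inp, out: fired.append("hook"))

x = torch.ones(1, 2)
layer(x)          # Module.__call__ runs forward() and then the registered hook
layer.forward(x)  # calling forward() directly skips the hook
print(fired)      # ['hook']
handle.remove()   # detach the hook when it is no longer needed
```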

FeedForwardTrainer(batch_size: int = 8, lr: float = 0.0001, epoch: int = 50, dropout_chance: float = 0.2, loss: str = 'mse', metrics: List[str] = None, hidden_layers: List[int] = None, hidden_activation: str = 'relu', last_activation: str = 'sigmoid', input_units: int = 8, output_units: int = 1, **kwargs) : Trainer step for a feed-forward PyTorch network, and the reference example of a `TorchBaseTrainerStep` subclass. `TorchBaseTrainerStep` is the base class that all PyTorch-based trainer steps should use.

Constructor for the BaseTrainerStep. All subclasses used for custom
training of user machine learning models should implement the `run_fn`,
`model_fn` and `input_fn` methods, which handle control flow, model training
and input data preparation, respectively.

Args:
    serving_model_dir: Directory indicating where to save the trained
     model.
    transform_output: Output of a preceding transform component.
    train_files: String, file pattern of the location of TFRecords for
     model training. Intended for use in the input_fn.
    eval_files: String, file pattern of the location of TFRecords for
     model evaluation. Intended for use in the input_fn.
    log_dir: Logs output directory.
    schema: Schema file from a preceding SchemaGen.
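FeedForwardTrainer takes its activations and loss as strings (`hidden_activation='relu'`, `last_activation='sigmoid'`, `loss='mse'`). Below is a minimal sketch of how such names might be resolved to `torch.nn` modules, assuming a simple lookup table. `resolve_activation`, `resolve_loss` and the table contents (including the `'binary_crossentropy'` key) are hypothetical. The names the real trainer accepts are not documented here.

```python
import torch.nn as nn

# Hypothetical lookup tables mapping option strings to torch.nn module classes.
ACTIVATIONS = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid, "tanh": nn.Tanh}
LOSSES = {"mse": nn.MSELoss, "binary_crossentropy": nn.BCELoss}


def _lookup(table: dict, name: str, kind: str) -> nn.Module:
    try:
        # Instantiate a fresh module on every call (case-insensitive lookup).
        return table[name.lower()]()
    except KeyError:
        raise ValueError(f"Unsupported {kind}: {name!r}") from None


def resolve_activation(name: str) -> nn.Module:
    return _lookup(ACTIVATIONS, name, "activation")


def resolve_loss(name: str) -> nn.Module:
    return _lookup(LOSSES, name, "loss")
```

Raising `ValueError` for unknown names surfaces a misspelled option when the step is configured, before training starts.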

### Ancestors (in MRO)

* zenml.core.steps.trainer.pytorch_trainers.torch_base_trainer.TorchBaseTrainerStep
* zenml.core.steps.trainer.base_trainer.BaseTrainerStep
* zenml.core.steps.base_step.BaseStep

### Methods

`input_fn(self, file_pattern: List[str], tf_transform_output: tensorflow_transform.output_wrapper.TFTransformOutput)`
:   Method for loading data from TFRecords saved to a location on
    disk. Override this method in subclasses to define your own custom
    data preparation flow.
    
    Args:
        file_pattern: File pattern matching saved TFRecords on disk.
        tf_transform_output: Output of the preceding Transform /
         Preprocessing component.
    
    Returns:
        dataset: A tf.data.Dataset constructed from the input file
         pattern and transform.
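The documented `input_fn` returns a `tf.data.Dataset`, whereas a PyTorch model usually consumes batches from a `torch.utils.data.DataLoader`. The sketch below shows one hypothetical bridge between the two. It assumes the TFRecords have already been parsed and transformed into a dict that maps each feature name to a column of floats; the TensorFlow parsing step is omitted. `records_to_loader` and `label_key` are illustrative names, and `batch_size=8` mirrors FeedForwardTrainer's default.

```python
from typing import Dict, List

import torch
from torch.utils.data import DataLoader, TensorDataset


def records_to_loader(records: Dict[str, List[float]], label_key: str,
                      batch_size: int = 8, shuffle: bool = True) -> DataLoader:
    """Hypothetical helper: batch already-parsed feature columns for PyTorch."""
    # Sort feature names so the column order is deterministic.
    feature_names = sorted(k for k in records if k != label_key)
    # Build an (num_features, N) tensor, then transpose to (N, num_features).
    features = torch.tensor([records[name] for name in feature_names],
                            dtype=torch.float32).T
    # Labels as an (N, 1) column to match a single-unit sigmoid output.
    labels = torch.tensor(records[label_key], dtype=torch.float32).unsqueeze(1)
    return DataLoader(TensorDataset(features, labels),
                      batch_size=batch_size, shuffle=shuffle)
```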

`model_fn(self, train_dataset, eval_dataset)`
:   Method defining the training flow of the model. Override this
    in subclasses to define your own custom training flow.
    
    Args:
        train_dataset: tf.data.Dataset containing the training data.
        eval_dataset: tf.data.Dataset containing the evaluation data.
    
    Returns:
        model: A trained machine learning model.
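A `model_fn` override for a PyTorch trainer would typically contain a standard training loop. Below is a minimal sketch under several assumptions: the model outputs probabilities, the loss is binary cross-entropy, and the optimizer is Adam (both are choices made for illustration). The `lr=1e-4` and `epochs=50` defaults mirror FeedForwardTrainer's `lr` and `epoch` parameters. `train_binary_model` is a hypothetical name, and evaluation on `eval_dataset` is left out for brevity.

```python
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_binary_model(model: nn.Module, train_loader: DataLoader, lr: float = 1e-4,
                       epochs: int = 50,
                       criterion: Optional[nn.Module] = None) -> Tuple[nn.Module, List[float]]:
    """Hypothetical training loop; returns the model and the mean loss of each epoch."""
    # Assumption: the model outputs probabilities, so binary cross-entropy fits.
    # FeedForwardTrainer's default loss='mse' would correspond to nn.MSELoss() instead.
    if criterion is None:
        criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history: List[float] = []
    for _ in range(epochs):
        model.train()  # enable dropout and other training-only behaviour
        total, seen = 0.0, 0
        for features, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(features), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact for a ragged last batch.
            total += loss.item() * features.shape[0]
            seen += features.shape[0]
        history.append(total / seen)
    return model, history
```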

`run_fn(self)`
:   Method defining the control flow of the training process inside
    the TFX Trainer Component Executor. Override this method in subclasses
    to define your own custom training flow.
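The division of labour between the three methods can be sketched with a small stand-in class. `TrainerStepSketch` is hypothetical and is not zenml's `BaseTrainerStep`. It keeps only the documented `train_files` and `eval_files` arguments and drops `tf_transform_output`. Saving the trained model to `serving_model_dir` is also omitted.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class TrainerStepSketch(ABC):
    """Hypothetical stand-in illustrating the run_fn / input_fn / model_fn contract."""

    def __init__(self, train_files: List[str], eval_files: List[str]):
        self.train_files = train_files
        self.eval_files = eval_files

    @abstractmethod
    def input_fn(self, file_pattern: List[str]) -> Any:
        """Subclasses turn a file pattern into a dataset."""

    @abstractmethod
    def model_fn(self, train_dataset: Any, eval_dataset: Any) -> Any:
        """Subclasses train and return a model."""

    def run_fn(self) -> Any:
        # Control flow: prepare both datasets, then hand them to model_fn.
        train_dataset = self.input_fn(self.train_files)
        eval_dataset = self.input_fn(self.eval_files)
        return self.model_fn(train_dataset, eval_dataset)
```

In this arrangement a subclass overrides only `input_fn` and `model_fn`, while `run_fn` keeps the ordering fixed. That matches the division of responsibilities described in the constructor docstring.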