RNN Cell API

Warning

This package is currently experimental and may change in the near future.

Overview

The rnn module includes the recurrent neural network (RNN) cell APIs, a suite of tools for building an RNN’s symbolic graph.

Note

The rnn module offers a higher-level interface, while symbol.RNN is a lower-level interface. The cell APIs in the rnn module are easier to use in most cases.

The rnn module

Cell interfaces

BaseRNNCell.__call__ Unroll the RNN for one time step.
BaseRNNCell.unroll Unroll an RNN cell across time steps.
BaseRNNCell.reset Reset before re-using the cell for another graph.
BaseRNNCell.begin_state Initial state for this cell.
BaseRNNCell.unpack_weights Unpack fused weight matrices into separate weight matrices.
BaseRNNCell.pack_weights Pack separate weight matrices into a single packed weight.

When working with the cell API, the precise input and output symbols depend on the type of RNN you are using. Take Long Short-Term Memory (LSTM) for example:

import mxnet as mx

# Hypothetical sizes, used by the examples on this page.
input_dim = 1000   # vocabulary size
embed_dim = 50     # embedding size

# Shape of 'step_data' is (batch_size,).
step_input = mx.symbol.Variable('step_data')

# First we embed our raw input data to be used as LSTM's input.
embedded_step = mx.symbol.Embedding(data=step_input,
                                    input_dim=input_dim,
                                    output_dim=embed_dim)

# Then we create an LSTM cell.
lstm_cell = mx.rnn.LSTMCell(num_hidden=50)
# Initialize its hidden and memory states.
# 'begin_state' method takes an initialization function, and uses 'zeros' by default.
begin_state = lstm_cell.begin_state()

The LSTM cell and other non-fused RNN cells are callable. Each call advances the cell by one time step and returns a new output and new states. This transformation depends on both the current input and the previous states. See this blog post for a great introduction to LSTM and other RNNs.

# Call the cell to get the output of one time step for a batch.
output, states = lstm_cell(embedded_step, begin_state)

# 'output' is lstm_t0_out_output of shape (batch_size, hidden_dim).

# 'states' has the recurrent states that will be carried over to the next step,
# which includes both the "hidden state" and the "cell state":
# Both 'lstm_t0_out_output' and 'lstm_t0_state_output' have shape (batch_size, hidden_dim).
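The following minimal NumPy sketch makes the shapes above concrete by computing what one LSTM step does. It uses the standard LSTM equations with hypothetical random weights. The gate order and weight names are illustrative assumptions, not MXNet's internal layout.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W_i2h, b_i2h, W_h2h, b_h2h):
    """One LSTM time step (standard formulation, NumPy sketch).

    x:      (batch_size, input_dim)
    h_prev: (batch_size, num_hidden)  previous hidden state
    c_prev: (batch_size, num_hidden)  previous cell state
    W_i2h:  (4 * num_hidden, input_dim); W_h2h: (4 * num_hidden, num_hidden)
    """
    gates = x @ W_i2h.T + b_i2h + h_prev @ W_h2h.T + b_h2h
    # Split the pre-activations into four gate blocks (order chosen for illustration).
    i, f, g, o = np.split(gates, 4, axis=1)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)   # new cell state
    h = sigmoid(o) * np.tanh(c)                         # new hidden state
    return h, [h, c]                                    # output, [hidden, cell]

batch_size, input_dim, num_hidden = 3, 8, 5
rng = np.random.default_rng(0)
x = rng.standard_normal((batch_size, input_dim))
h0 = np.zeros((batch_size, num_hidden))   # zero states, like begin_state()
c0 = np.zeros((batch_size, num_hidden))
W_i2h = rng.standard_normal((4 * num_hidden, input_dim)) * 0.1
W_h2h = rng.standard_normal((4 * num_hidden, num_hidden)) * 0.1
b_i2h = np.zeros(4 * num_hidden)
b_h2h = np.zeros(4 * num_hidden)

output, states = lstm_step(x, h0, c0, W_i2h, b_i2h, W_h2h, b_h2h)
print(output.shape, [s.shape for s in states])  # (3, 5) [(3, 5), (3, 5)]
```

As in the symbolic example, the output is the new hidden state. Both states have shape (batch_size, num_hidden).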

Most of the time our goal is to process a sequence of many steps. For this, we need to unroll the LSTM according to the sequence length.

# Embed a sequence. 'seq_data' has the shape of (batch_size, sequence_length).
sequence_length = 5   # hypothetical; the 'lstm_t4_*' names below assume 5 steps
seq_input = mx.symbol.Variable('seq_data')
embedded_seq = mx.symbol.Embedding(data=seq_input,
                                   input_dim=input_dim,
                                   output_dim=embed_dim)

Note

Remember to reset the cell when unrolling/stepping for a new sequence by calling lstm_cell.reset().

# Reset the cell before unrolling it for a new sequence.
lstm_cell.reset()
# When 'merge_outputs' is True, the per-step outputs are merged into a single symbol.
# In the layout, 'N' represents batch size, 'T' represents sequence length, and 'C' represents the
# number of dimensions in hidden states.
outputs, states = lstm_cell.unroll(length=sequence_length,
                                   inputs=embedded_seq,
                                   layout='NTC',
                                   merge_outputs=True)
# 'outputs' is concat0_output of shape (batch_size, sequence_length, hidden_dim).
# The hidden state and cell state from the final time step are returned:
# Both 'lstm_t4_out_output' and 'lstm_t4_state_output' have shape (batch_size, hidden_dim).

# If merge_outputs is set to False, a list of symbols, one per time step, is returned.
lstm_cell.reset()
outputs, states = lstm_cell.unroll(length=sequence_length,
                                   inputs=embedded_seq,
                                   layout='NTC',
                                   merge_outputs=False)
# In this case, 'outputs' is a list of symbols. Each symbol is of shape (batch_size, hidden_dim).
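Conceptually, unrolling calls the cell once per time step and threads the states from one call to the next. The two merge_outputs modes differ only in whether the per-step outputs are stacked along the time axis. The NumPy sketch below illustrates this. The `unroll` and `rnn_step` functions are hypothetical helpers, and a simple tanh RNN stands in for the LSTM to keep it short.

```python
import numpy as np

def rnn_step(x, h_prev, W_i2h, W_h2h, b):
    """One step of a simple tanh RNN, standing in for any cell's __call__."""
    h = np.tanh(x @ W_i2h.T + h_prev @ W_h2h.T + b)
    return h, [h]

def unroll(step, inputs, begin_state, layout='NTC', merge_outputs=True):
    """Call `step` once per time step, threading the states through."""
    time_axis = layout.find('T')
    states = begin_state
    outputs = []
    for t in range(inputs.shape[time_axis]):
        x_t = np.take(inputs, t, axis=time_axis)   # (batch_size, C)
        out, states = step(x_t, *states)
        outputs.append(out)
    if merge_outputs:
        # One array, with the time axis where the layout puts it.
        outputs = np.stack(outputs, axis=time_axis)
    return outputs, states

batch_size, seq_len, embed_dim, num_hidden = 2, 5, 4, 3
rng = np.random.default_rng(0)
seq = rng.standard_normal((batch_size, seq_len, embed_dim))   # 'NTC'
W_i2h = rng.standard_normal((num_hidden, embed_dim))
W_h2h = rng.standard_normal((num_hidden, num_hidden))
b = np.zeros(num_hidden)
step = lambda x, h: rnn_step(x, h, W_i2h, W_h2h, b)
begin = [np.zeros((batch_size, num_hidden))]   # like begin_state()

merged, states = unroll(step, seq, begin, merge_outputs=True)
listed, _ = unroll(step, seq, begin, merge_outputs=False)
print(merged.shape)                  # (2, 5, 3)
print(len(listed), listed[0].shape)  # 5 (2, 3)
```

The final state returned by the unroll is the same as the last step's output, which matches how the symbolic unroll returns the states of the final time step.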

Note

Loading and saving models built with the RNN cell API requires using mx.rnn.load_rnn_checkpoint, mx.rnn.save_rnn_checkpoint, and mx.rnn.do_rnn_checkpoint. The list of all the cells used should be provided as the first argument to those functions.

Basic RNN cells

The rnn module supports the following RNN cell types.

LSTMCell Long Short-Term Memory (LSTM) network cell.
GRUCell Gated Recurrent Unit (GRU) network cell.
RNNCell Simple recurrent neural network cell.

Modifier cells

BidirectionalCell Bidirectional RNN cell.
DropoutCell Apply dropout on input.
ZoneoutCell Apply Zoneout on base cell.
ResidualCell Adds residual connection as described in Wu et al, 2016 (https://arxiv.org/abs/1609.08144).

A modifier cell takes in one or more cells and transforms the output of those cells. BidirectionalCell is one example. It takes two cells for forward unroll and backward unroll respectively. After unrolling, the outputs of the forward and backward pass are concatenated.

# Bidirectional cell takes two RNN cells, for forward and backward pass respectively.
# Having different types of cells for forward and backward unrolling is allowed.
bi_cell = mx.rnn.BidirectionalCell(
                 mx.rnn.LSTMCell(num_hidden=50),
                 mx.rnn.GRUCell(num_hidden=75))
outputs, states = bi_cell.unroll(length=sequence_length,
                                 inputs=embedded_seq,
                                 merge_outputs=True)
# The output feature is the concatenation of the forward and backward pass.
# Thus, the number of output dimensions is the sum of the dimensions of the two cells.
# 'outputs' is the symbol 'bi_out_output' of shape (batch_size, sequence_length, 125).

# The states of the BidirectionalCell are a list of two lists, corresponding to the
# states of the forward and backward cells respectively.
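The following NumPy sketch shows how the forward and backward passes combine. The backward cell reads the reversed sequence, and its outputs are flipped back so that position t in both halves refers to the same time step. The two halves are then concatenated along the feature axis. The helper names and the simple tanh cells are hypothetical stand-ins for the LSTM and GRU cells above.

```python
import numpy as np

def make_tanh_cell(input_dim, num_hidden, rng):
    """Return (step function, num_hidden) for a simple tanh RNN with random weights."""
    W_i2h = rng.standard_normal((num_hidden, input_dim)) * 0.1
    W_h2h = rng.standard_normal((num_hidden, num_hidden)) * 0.1
    def step(x, h):
        return np.tanh(x @ W_i2h.T + h @ W_h2h.T)
    return step, num_hidden

def run(step, num_hidden, seq):
    """Unroll over axis 1 of an 'NTC' array and stack the outputs along time."""
    h = np.zeros((seq.shape[0], num_hidden))
    outs = []
    for t in range(seq.shape[1]):
        h = step(seq[:, t], h)
        outs.append(h)
    return np.stack(outs, axis=1)

def bidirectional(l_cell, r_cell, seq):
    fwd = run(*l_cell, seq)
    # Backward pass: feed the reversed sequence, then flip the outputs back.
    bwd = run(*r_cell, seq[:, ::-1])[:, ::-1]
    return np.concatenate([fwd, bwd], axis=2)

rng = np.random.default_rng(0)
l_cell = make_tanh_cell(16, 50, rng)
r_cell = make_tanh_cell(16, 75, rng)
seq = rng.standard_normal((2, 6, 16))    # (batch_size, sequence_length, embed_dim)
out = bidirectional(l_cell, r_cell, seq)
print(out.shape)   # (2, 6, 125)
```

The backward half at step 0 already depends on the last input of the sequence, which is why a bidirectional cell needs the whole sequence and cannot be stepped.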

Note

BidirectionalCell cannot be called or stepped, because the backward unroll requires the output of future steps, and thus the whole sequence is required.

Dropout and zoneout are popular regularization techniques that can be applied to RNNs. The rnn module provides DropoutCell and ZoneoutCell for regularizing the outputs and recurrent states of an RNN. ZoneoutCell takes one RNN cell in its constructor and supports unrolling like other cells.

zoneout_cell = mx.rnn.ZoneoutCell(lstm_cell, zoneout_states=0.5)
outputs, states = zoneout_cell.unroll(length=sequence_length,
                                      inputs=embedded_seq,
                                      merge_outputs=True)
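Zoneout differs from dropout. Instead of zeroing a unit, zoneout keeps that unit's previous value. With zoneout_states=0.5, roughly half of the state units are expected to carry over unchanged at each training step. The NumPy sketch below is our illustration of that training-time rule, not ZoneoutCell's code, and it does not model inference-time behavior.

```python
import numpy as np

def zoneout(prev, new, p, rng):
    """Training-time zoneout: each unit keeps its previous value with probability p."""
    keep_prev = rng.random(new.shape) < p
    return np.where(keep_prev, prev, new)

rng = np.random.default_rng(0)
prev_state = np.zeros((4, 1000))
new_state = np.ones((4, 1000))
mixed = zoneout(prev_state, new_state, p=0.5, rng=rng)
print(mixed.mean())  # roughly 0.5: about half of the units kept their old value
```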

DropoutCell performs dropout on the input sequence. It can be used in a stacked multi-layer RNN setting, which we will cover next.
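DropoutCell's dropout argument is the fraction of elements that are dropped. The NumPy sketch below shows the common "inverted dropout" formulation. Each element is zeroed with probability p, and the survivors are scaled by 1 / (1 - p) so that the expected activation is unchanged. We believe MXNet's Dropout operator scales this way during training, but treat that as an assumption. The `dropout` helper is hypothetical.

```python
import numpy as np

def dropout(x, p, rng):
    """Inverted dropout: zero each element with probability p and scale the
    survivors by 1 / (1 - p) so the expected value is unchanged."""
    if not 0.0 <= p < 1.0:
        raise ValueError('p must be in [0, 1)')
    keep = rng.random(x.shape) >= p
    return x * keep / (1.0 - p)

rng = np.random.default_rng(0)
h = np.ones((8, 500))          # stand-in for one layer's outputs
dropped = dropout(h, 0.5, rng)
print(np.unique(dropped))      # [0. 2.]
```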

Residual connection is a useful technique for training deep neural models because it helps the propagation of gradients by shortening the paths. ResidualCell provides such functionality for RNN models.

residual_cell = mx.rnn.ResidualCell(lstm_cell)
outputs, states = residual_cell.unroll(length=sequence_length,
                                       inputs=embedded_seq,
                                       merge_outputs=True)

The outputs are the element-wise sum of both the input and the output of the LSTM cell.
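Because the sum is element-wise, the cell's input and output must have the same shape. In the example above, embed_dim has to equal the LSTM's num_hidden (50). The minimal NumPy sketch below wraps a hypothetical tanh cell. It assumes, as we understand ResidualCell, that only the output is changed and the base cell's states are passed through unchanged.

```python
import numpy as np

def residual(step):
    """Wrap a step function so that its output becomes cell_output + input."""
    def wrapped(x, *states):
        out, new_states = step(x, *states)
        if out.shape != x.shape:
            raise ValueError(f'residual connection needs matching shapes, '
                             f'got input {x.shape} and output {out.shape}')
        return out + x, new_states   # states are passed through unchanged
    return wrapped

rng = np.random.default_rng(0)
num_hidden = 4
W = rng.standard_normal((num_hidden, num_hidden)) * 0.1
U = rng.standard_normal((num_hidden, num_hidden)) * 0.1

def tanh_step(x, h):
    h = np.tanh(x @ W.T + h @ U.T)
    return h, [h]

x = rng.standard_normal((2, num_hidden))   # input dim == num_hidden
h0 = np.zeros((2, num_hidden))
base_out, _ = tanh_step(x, h0)
res_out, res_states = residual(tanh_step)(x, h0)
print(np.allclose(res_out, base_out + x))  # True
```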

Multi-layer cells

SequentialRNNCell Sequentially stacking multiple RNN cells.
SequentialRNNCell.add Append a cell into the stack.

The SequentialRNNCell allows stacking multiple layers of RNN cells to improve the expressiveness and performance of the model. Cells can be added to a SequentialRNNCell in order, from bottom to top. When unrolling, the output of a lower-level cell is automatically passed to the cell above.

stacked_rnn_cells = mx.rnn.SequentialRNNCell()
stacked_rnn_cells.add(mx.rnn.BidirectionalCell(
                          mx.rnn.LSTMCell(num_hidden=50),
                          mx.rnn.LSTMCell(num_hidden=50)))

# Apply dropout with probability 0.5 to the output of the bottom-layer BidirectionalCell.
stacked_rnn_cells.add(mx.rnn.DropoutCell(0.5))

stacked_rnn_cells.add(mx.rnn.LSTMCell(num_hidden=50))
outputs, states = stacked_rnn_cells.unroll(length=sequence_length,
                                           inputs=embedded_seq,
                                           merge_outputs=True)

# The output of SequentialRNNCell is the same as that of the last layer.
# In this case 'outputs' is the symbol 'concat6_output' of shape (batch_size, sequence_length, hidden_dim)
# The states of the SequentialRNNCell are a list of lists, with each list
# corresponding to the states of each of the added cells respectively.
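The stacking rule (each layer's output sequence becomes the next layer's input) can be sketched in NumPy as follows. Each "layer" here is a hypothetical function that unrolls a simple tanh RNN over an 'NTC' array. The first layer has 100 hidden units, like the 50 + 50 bidirectional bottom layer above, and the top layer has 50. A dropout layer would slot in as one more function in the list.

```python
import numpy as np

def tanh_layer(input_dim, num_hidden, rng):
    """Return a function that unrolls a simple tanh RNN over an 'NTC' array."""
    W_i2h = rng.standard_normal((num_hidden, input_dim)) * 0.1
    W_h2h = rng.standard_normal((num_hidden, num_hidden)) * 0.1
    def unroll(seq):
        h = np.zeros((seq.shape[0], num_hidden))
        outs = []
        for t in range(seq.shape[1]):
            h = np.tanh(seq[:, t] @ W_i2h.T + h @ W_h2h.T)
            outs.append(h)
        return np.stack(outs, axis=1)
    return unroll

def sequential(layers, seq):
    """Feed each layer's output sequence to the next layer, bottom to top."""
    for layer in layers:
        seq = layer(seq)
    return seq

rng = np.random.default_rng(0)
seq = rng.standard_normal((2, 7, 16))        # (batch_size, sequence_length, embed_dim)
stack = [tanh_layer(16, 100, rng), tanh_layer(100, 50, rng)]
out = sequential(stack, seq)
print(out.shape)  # (2, 7, 50)
```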

Fused RNN cell

FusedRNNCell Fusing RNN layers across time steps into one kernel.
FusedRNNCell.unfuse Unfuse the fused RNN into a stack of RNN cells.

The computation of an RNN over an input sequence consists of many GEMM and point-wise operations with temporal dependencies. This can make the computation memory-bound, especially on GPU, resulting in longer wall-clock time. By combining the computation of many small matrices into that of larger ones and streaming the computation whenever possible, the ratio of computation to memory I/O can be increased, which results in better performance on GPU. This optimization technique is called “fusing”. This post discusses fusing in greater detail.
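A few lines of NumPy can show one ingredient of fusing. The input-to-hidden projection has no temporal dependency, so the per-step GEMMs can be folded into one larger GEMM over all time steps at once. This is only an illustration of the idea, not how cuDNN or FusedRNNCell implement it. The hidden-to-hidden part still has to run step by step.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, input_dim, gate_dim = 4, 10, 32, 4 * 64
x = rng.standard_normal((seq_len, batch_size, input_dim))   # 'TNC' layout
W_i2h = rng.standard_normal((gate_dim, input_dim))

# Unfused: one small GEMM per time step.
per_step = np.stack([x[t] @ W_i2h.T for t in range(seq_len)])

# Fused: fold the time axis into the batch axis and do a single large GEMM.
fused = (x.reshape(seq_len * batch_size, input_dim) @ W_i2h.T).reshape(
    seq_len, batch_size, gate_dim)

print(np.allclose(per_step, fused))  # True
```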

The rnn module includes a FusedRNNCell, which provides the optimized fused implementation. The FusedRNNCell supports bidirectional RNNs and dropout.

fused_lstm_cell = mx.rnn.FusedRNNCell(num_hidden=50,
                                      num_layers=3,
                                      mode='lstm',
                                      bidirectional=True,
                                      dropout=0.5)
outputs, _ = fused_lstm_cell.unroll(length=sequence_length,
                                    inputs=embedded_seq,
                                    merge_outputs=True)
# 'outputs' is the symbol 'lstm_rnn_output' that has the shape
# (batch_size, sequence_length, forward_backward_concat_dim)

Note

FusedRNNCell is supported only on GPU (it requires cuDNN). It cannot be called or stepped.

Note

When dropout is set to non-zero in FusedRNNCell, the dropout is applied to the output of all layers except the last layer. If there is only one layer in the FusedRNNCell, the dropout rate is ignored.

Note

Similar to BidirectionalCell, when the bidirectional flag is set to True, the output dimension of FusedRNNCell is twice num_hidden.

When training a deep, complex model on multiple GPUs, it’s recommended to stack fused RNN cells (one layer per cell) instead of using a single cell that contains all layers. The reason is that a fused RNN cell doesn’t mark its gradients as ready until the computation for the entire cell is completed. Breaking a multi-layer fused RNN cell into several one-layer cells allows gradients to be processed earlier. This reduces communication overhead, especially with multiple GPUs.

The unfuse() method can be used to convert the FusedRNNCell into an equivalent and CPU-compatible SequentialRNNCell that mirrors the settings of the FusedRNNCell.

unfused_lstm_cell = fused_lstm_cell.unfuse()
unfused_outputs, _ = unfused_lstm_cell.unroll(length=sequence_length,
                                              inputs=embedded_seq,
                                              merge_outputs=True)
# 'unfused_outputs' is the symbol 'lstm_bi_l2_out_output' that has the shape
# (batch_size, sequence_length, forward_backward_concat_dim)

RNN checkpoint methods and parameters

save_rnn_checkpoint Save checkpoint for model using RNN cells.
load_rnn_checkpoint Load model checkpoint from file.
do_rnn_checkpoint Make a callback to checkpoint Module to prefix every epoch.
RNNParams Container for holding variables.
RNNParams.get Get the variable given a name if one exists or create a new one if missing.

The model parameters from training with a fused cell can be used for inference with an unfused cell, and vice versa. Because the parameters of fused and unfused cells are organized differently, they need to be converted first. FusedRNNCell's parameters are merged and flattened. In the fused example above, the model has lstm_parameters of shape (total_num_params,), whereas the equivalent SequentialRNNCell’s parameters are separate (for an LSTM, out_dim = 4 * hidden_dim, one block per gate):

'lstm_l0_i2h_weight': (out_dim, embed_dim)
'lstm_l0_i2h_bias': (out_dim,)
'lstm_l0_h2h_weight': (out_dim, hidden_dim)
'lstm_l0_h2h_bias': (out_dim,)
'lstm_r0_i2h_weight': (out_dim, embed_dim)
...

All cells in the rnn module support the method unpack_weights() for converting FusedRNNCell parameters to the unfused format and pack_weights() for fusing the parameters. The RNN-specific checkpointing methods (load_rnn_checkpoint, save_rnn_checkpoint, do_rnn_checkpoint) handle the conversion transparently based on the provided cells.
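The bookkeeping behind packing and unpacking can be sketched in NumPy. `unfused_param_shapes` reproduces the per-layer names and shapes listed above for the FusedRNNCell example, assuming embed_dim = 50. `pack` and `unpack` flatten those arrays into one vector and back. All three helpers are hypothetical. The real fused layout (the order of weights, biases, and gates inside cuDNN's buffer) differs, so this only shows the sizes and the round trip.

```python
import numpy as np

def unfused_param_shapes(prefix, num_hidden, num_layers, input_dim,
                         bidirectional, num_gates=4):
    """Shapes of the per-layer weights, following the naming shown above."""
    directions = ['l', 'r'] if bidirectional else ['l']
    out_dim = num_gates * num_hidden
    shapes = {}
    for layer in range(num_layers):
        # Layers above the first see the concatenated outputs of all directions.
        in_dim = input_dim if layer == 0 else num_hidden * len(directions)
        for d in directions:
            name = f'{prefix}{d}{layer}_'
            shapes[name + 'i2h_weight'] = (out_dim, in_dim)
            shapes[name + 'i2h_bias'] = (out_dim,)
            shapes[name + 'h2h_weight'] = (out_dim, num_hidden)
            shapes[name + 'h2h_bias'] = (out_dim,)
    return shapes

def pack(params, order):
    """Flatten the named arrays in `order` into a single vector."""
    return np.concatenate([params[n].ravel() for n in order])

def unpack(flat, shapes, order):
    """Inverse of `pack`: slice the vector back into named arrays."""
    params, offset = {}, 0
    for n in order:
        size = int(np.prod(shapes[n]))
        params[n] = flat[offset:offset + size].reshape(shapes[n])
        offset += size
    if offset != flat.size:
        raise ValueError('vector length does not match the given shapes')
    return params

shapes = unfused_param_shapes('lstm_', num_hidden=50, num_layers=3,
                              input_dim=50, bidirectional=True)
order = list(shapes)
rng = np.random.default_rng(0)
params = {n: rng.standard_normal(s) for n, s in shapes.items()}
flat = pack(params, order)
restored = unpack(flat, shapes, order)
print(len(shapes), flat.shape)  # 24 (162400,)
```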

I/O utilities

BucketSentenceIter Simple bucketing iterator for language models.
encode_sentences Encode sentences and (optionally) build a mapping from string tokens to integer indices.

API Reference

class mxnet.rnn.BaseRNNCell(prefix='', params=None)

Abstract base class for RNN cells.

Parameters:
  • prefix (str, optional) – Prefix for names of layers (this prefix is also used for names of weights if params is None i.e. if params are being created and not reused)
  • params (RNNParams, default None.) – Container for weight sharing between cells. A new RNNParams container is created if params is None.
__call__(inputs, states)

Unroll the RNN for one time step.

Parameters:
  • inputs (sym.Variable) – input symbol, 2D, batch * num_units
  • states (list of sym.Variable) – RNN state from previous step or the output of begin_state().
Returns:

  • output (Symbol) – Symbol corresponding to the output from the RNN when unrolling for a single time step.
  • states (nested list of Symbol) – The new state of this RNN after this unrolling. The type of this symbol is the same as the output of begin_state(). This can be used as input state to the next time step of this RNN.

See also

begin_state()
This function can provide the states for the first time step.
unroll()
This function unrolls an RNN for a given number of (>=1) time steps.
reset()

Reset before re-using the cell for another graph.

params

Parameters of this cell

state_info

shape and layout information of states

state_shape

shape(s) of states

begin_state(func=<function zeros>, **kwargs)

Initial state for this cell.

Parameters:
  • func (callable, default symbol.zeros) – Function for creating initial state. Can be symbol.zeros, symbol.uniform, symbol.Variable etc. Use symbol.Variable if you want to directly feed input as states.
  • **kwargs

    more keyword arguments passed to func. For example mean, std, dtype, etc.

Returns:

states – Starting states for the first RNN step.

Return type:

nested list of Symbol

unpack_weights(args)

Unpack fused weight matrices into separate weight matrices.

For example, say you use a module object mod to run a network that has an lstm cell. In mod.get_params()[0], the lstm parameters are all represented as a single big vector. cell.unpack_weights(mod.get_params()[0]) will unpack this vector into a dictionary of more readable lstm parameters - c, f, i, o gates for i2h (input to hidden) and h2h (hidden to hidden) weights.

Parameters:args (dict of str -> NDArray) – Dictionary containing packed weights. Usually from Module.get_params()[0].
Returns:args – Dictionary with unpacked weights associated with this cell.
Return type:dict of str -> NDArray

See also

pack_weights()
Performs the reverse operation of this function.
pack_weights(args)

Pack separate weight matrices into a single packed weight.

Parameters:args (dict of str -> NDArray) – Dictionary containing unpacked weights.
Returns:args – Dictionary with packed weights associated with this cell.
Return type:dict of str -> NDArray
unroll(length, inputs, begin_state=None, layout='NTC', merge_outputs=None)

Unroll an RNN cell across time steps.

Parameters:
  • length (int) – Number of steps to unroll.
  • inputs (Symbol, list of Symbol, or None) –

    If inputs is a single Symbol (usually the output of Embedding symbol), it should have shape (batch_size, length, ...) if layout == ‘NTC’, or (length, batch_size, ...) if layout == ‘TNC’.

    If inputs is a list of symbols (usually output of previous unroll), they should all have shape (batch_size, ...).

  • begin_state (nested list of Symbol, default None) – Input states created by begin_state() or output state of another cell. Created from begin_state() if None.
  • layout (str, optional) – layout of input symbol. Only used if inputs is a single Symbol.
  • merge_outputs (bool, optional) – If False, return outputs as a list of Symbols. If True, concatenate output across time steps and return a single symbol with shape (batch_size, length, ...) if layout == ‘NTC’, or (length, batch_size, ...) if layout == ‘TNC’. If None, output whatever is faster.
Returns:

  • outputs (list of Symbol or Symbol) – Symbol (if merge_outputs is True) or list of Symbols (if merge_outputs is False) corresponding to the output from the RNN from this unrolling.
  • states (nested list of Symbol) – The new state of this RNN after this unrolling. The type of this symbol is the same as the output of begin_state().

class mxnet.rnn.LSTMCell(num_hidden, prefix='lstm_', params=None, forget_bias=1.0)

Long-Short Term Memory (LSTM) network cell.

Parameters:
  • num_hidden (int) – Number of units in output symbol.
  • prefix (str, default ‘lstm_’) – Prefix for name of layers (and name of weight if params is None).
  • params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
  • forget_bias (float, default 1.0) – Bias added to the forget gate. Jozefowicz et al. 2015 recommends setting this to 1.0.
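A quick worked example shows why a forget bias of 1.0 helps early in training. The sketch assumes the other contributions to the forget gate start near zero, so its pre-activation is dominated by forget_bias at initialization. The sketch then compares how much of the cell state survives one step and ten steps, ignoring new input written into the cell.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

for forget_bias in (0.0, 1.0):
    f = sigmoid(forget_bias)          # forget-gate value at initialization
    # Fraction of the cell state kept after 1 step and after 10 steps.
    print(forget_bias, round(f, 3), round(f ** 10, 3))
# 0.0 0.5 0.001
# 1.0 0.731 0.044
```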
class mxnet.rnn.GRUCell(num_hidden, prefix='gru_', params=None)

Gated Recurrent Unit (GRU) network cell. Note: this is an implementation of the cuDNN version of GRUs (slight modification compared to Cho et al. 2014).

Parameters:
  • num_hidden (int) – Number of units in output symbol.
  • prefix (str, default ‘gru_’) – Prefix for name of layers (and name of weight if params is None).
  • params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
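As we understand the cuDNN variant, it differs from Cho et al. 2014 in where the reset gate is applied when computing the candidate state. The cuDNN variant scales the already-projected previous state, while Cho et al. scale the previous state before the projection. The NumPy sketch below compares only the candidate computation, with hypothetical random weights and gate values. The two variants agree when the reset gate is all ones and differ otherwise.

```python
import numpy as np

def candidate_cudnn(x, h, W_in, b_in, W_hn, b_hn, r):
    # cuDNN-style: the reset gate scales the already-projected previous state.
    return np.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))

def candidate_cho(x, h, W_in, b_in, W_hn, b_hn, r):
    # Cho et al. 2014: the reset gate scales the previous state before projection.
    return np.tanh(x @ W_in.T + b_in + (r * h) @ W_hn.T + b_hn)

rng = np.random.default_rng(0)
batch_size, input_dim, num_hidden = 2, 3, 4
x = rng.standard_normal((batch_size, input_dim))
h = rng.standard_normal((batch_size, num_hidden))
W_in = rng.standard_normal((num_hidden, input_dim))
W_hn = rng.standard_normal((num_hidden, num_hidden))
b_in = rng.standard_normal(num_hidden)
b_hn = rng.standard_normal(num_hidden)
r = 1.0 / (1.0 + np.exp(-rng.standard_normal((batch_size, num_hidden))))  # reset gate in (0, 1)

same = np.allclose(candidate_cudnn(x, h, W_in, b_in, W_hn, b_hn, r),
                   candidate_cho(x, h, W_in, b_in, W_hn, b_hn, r))
print(same)  # False: the two variants differ for a general reset gate
```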
class mxnet.rnn.RNNCell(num_hidden, activation='tanh', prefix='rnn_', params=None)

Simple recurrent neural network cell.

Parameters:
  • num_hidden (int) – Number of units in output symbol.
  • activation (str or Symbol, default 'tanh') – Type of activation function. Options are ‘relu’ and ‘tanh’.
  • prefix (str, default ‘rnn_’) – Prefix for name of layers (and name of weight if params is None).
  • params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
class mxnet.rnn.FusedRNNCell(num_hidden, num_layers=1, mode='lstm', bidirectional=False, dropout=0.0, get_next_state=False, forget_bias=1.0, prefix=None, params=None)

Fusing RNN layers across time step into one kernel. Improves speed but is less flexible. Currently only supported if using cuDNN on GPU.

Parameters:
  • num_hidden (int) – Number of units in output symbol.
  • num_layers (int, default 1) – Number of layers in the cell.
  • mode (str, default 'lstm') – Type of RNN. options are ‘rnn_relu’, ‘rnn_tanh’, ‘lstm’, ‘gru’.
  • bidirectional (bool, default False) – Whether to use bidirectional unroll. The output dimension size is doubled if bidirectional.
  • dropout (float, default 0.) – Fraction of the input that gets dropped out during training time.
  • get_next_state (bool, default False) – Whether to return the states that can be used as starting states next time.
  • forget_bias (float, default 1.0) – Bias added to the forget gate. Jozefowicz et al. 2015 recommends setting this to 1.0.
  • prefix (str, default ‘$mode_’, such as ‘lstm_’) – Prefix for names of layers (this prefix is also used for names of weights if params is None i.e. if params are being created and not reused)
  • params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
unfuse()

Unfuse the fused RNN into a stack of RNN cells.

Returns:cell – unfused cell that can be used for stepping, and can run on CPU.
Return type:SequentialRNNCell
class mxnet.rnn.SequentialRNNCell(params=None)

Sequentially stacking multiple RNN cells.

Parameters:params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
add(cell)

Append a cell into the stack.

Parameters:cell (BaseRNNCell) – The cell to be appended. During unroll, previous cell’s output (or raw inputs if no previous cell) is used as the input to this cell.
class mxnet.rnn.BidirectionalCell(l_cell, r_cell, params=None, output_prefix='bi_')

Bidirectional RNN cell.

Parameters:
  • l_cell (BaseRNNCell) – cell for forward unrolling
  • r_cell (BaseRNNCell) – cell for backward unrolling
  • params (RNNParams, default None.) – Container for weight sharing between cells. A new RNNParams container is created if params is None.
  • output_prefix (str, default ‘bi_’) – prefix for name of output
class mxnet.rnn.DropoutCell(dropout, prefix='dropout_', params=None)

Apply dropout on input.

Parameters:
  • dropout (float) – Fraction of elements to drop out, which is 1 minus the fraction to retain.
  • prefix (str, default ‘dropout_’) – Prefix for names of layers (this prefix is also used for names of weights if params is None i.e. if params are being created and not reused)
  • params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
class mxnet.rnn.ZoneoutCell(base_cell, zoneout_outputs=0.0, zoneout_states=0.0)

Apply Zoneout on base cell.

Parameters:
  • base_cell (BaseRNNCell) – Cell on whose states to perform zoneout.
  • zoneout_outputs (float, default 0.) – Fraction of the output that gets dropped out during training time.
  • zoneout_states (float, default 0.) – Fraction of the states that gets dropped out during training time.
class mxnet.rnn.ResidualCell(base_cell)

Adds residual connection as described in Wu et al, 2016 (https://arxiv.org/abs/1609.08144).

Output of the cell is output of the base cell plus input.

Parameters:base_cell (BaseRNNCell) – Cell on whose outputs to add residual connection.
class mxnet.rnn.RNNParams(prefix='')

Container for holding variables. Used by RNN cells for parameter sharing between cells.

Parameters:prefix (str) – Names of all variables created by this container will be prepended with prefix.
get(name, **kwargs)

Get the variable given a name if one exists or create a new one if missing.

Parameters:
  • name (str) – name of the variable
  • **kwargs

    more arguments that are passed to symbol.Variable
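The get-or-create behavior is what makes weight sharing work. Repeated lookups of the same name, for example from two cells given the same container, return the same variable. Below is a pure-Python sketch of that behavior. `ParamsSketch` is a hypothetical stand-in, and a plain dict takes the place of the symbol.Variable that RNNParams would create.

```python
class ParamsSketch:
    """Hypothetical stand-in for RNNParams: a prefixed get-or-create registry."""

    def __init__(self, prefix=''):
        self._prefix = prefix
        self._params = {}

    def get(self, name, **kwargs):
        full_name = self._prefix + name
        if full_name not in self._params:
            # A dict stands in for symbol.Variable(full_name, **kwargs).
            self._params[full_name] = {'name': full_name, **kwargs}
        return self._params[full_name]

shared = ParamsSketch(prefix='lstm_')
w_a = shared.get('i2h_weight')
w_b = shared.get('i2h_weight')   # second lookup reuses the first variable
print(w_a is w_b, w_a['name'])   # True lstm_i2h_weight
```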

class mxnet.rnn.BucketSentenceIter(sentences, batch_size, buckets=None, invalid_label=-1, data_name='data', label_name='softmax_label', dtype='float32', layout='NT')

Simple bucketing iterator for language models. The label at each sequence step is the following token in the sequence.

Parameters:
  • sentences (list of list of int) – Encoded sentences.
  • batch_size (int) – Batch size of the data.
  • invalid_label (int, optional) – Key for invalid label, e.g. <end-of-sentence>. The default is -1.
  • dtype (str, optional) – Data type of the encoding. The default data type is ‘float32’.
  • buckets (list of int, optional) – Size of the data buckets. Automatically generated if None.
  • data_name (str, optional) – Name of the data. The default name is ‘data’.
  • label_name (str, optional) – Name of the label. The default name is ‘softmax_label’.
  • layout (str, optional) – Format of data and label. ‘NT’ means (batch_size, length) and ‘TN’ means (length, batch_size).
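Bucketing groups sentences of similar length so that each batch can be unrolled to a fixed length. The pure-Python sketch below shows the bucketing and the shifted labels described above. Each sentence goes to the smallest bucket that holds it and is padded with invalid_label. `bucket_sentences` is a hypothetical helper. Discarding sentences longer than the largest bucket is our assumption about the iterator's behavior. Batching and shuffling are not shown.

```python
import bisect

def bucket_sentences(sentences, buckets, invalid_label=-1):
    """Group encoded sentences into fixed-length buckets (sketch)."""
    buckets = sorted(buckets)
    data = {b: [] for b in buckets}
    for sent in sentences:
        # Index of the smallest bucket whose length is >= len(sent).
        i = bisect.bisect_left(buckets, len(sent))
        if i == len(buckets):
            continue  # longer than the largest bucket: discarded
        b = buckets[i]
        data[b].append(sent + [invalid_label] * (b - len(sent)))
    # The label at each step is the next token; the last position gets invalid_label.
    labels = {b: [row[1:] + [invalid_label] for row in rows]
              for b, rows in data.items()}
    return data, labels

data, labels = bucket_sentences([[1, 2], [3, 4, 5, 6], [7, 8, 9], [1] * 9],
                                buckets=[3, 5])
print(data)    # {3: [[1, 2, -1], [7, 8, 9]], 5: [[3, 4, 5, 6, -1]]}
print(labels)  # {3: [[2, -1, -1], [8, 9, -1]], 5: [[4, 5, 6, -1, -1]]}
```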
reset()

Resets the iterator to the beginning of the data.

next()

Returns the next batch of data.

rnn.encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n', start_label=0)

Encode sentences and (optionally) build a mapping from string tokens to integer indices. Unknown keys will be added to vocabulary.

Parameters:
  • sentences (list of list of str) – A list of sentences to encode. Each sentence should be a list of string tokens.
  • vocab (None or dict of str -> int) – Optional input Vocabulary
  • invalid_label (int, default -1) – Index for invalid token, like <end-of-sentence>
  • invalid_key (str, default '\n') – Key for invalid token. Uses '\n' (newline) for end of sentence by default.
  • start_label (int) – Lowest index assigned to a token.
Returns:

  • result (list of list of int) – encoded sentences
  • vocab (dict of str -> int) – result vocabulary
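The encoding rules can be sketched in pure Python. New tokens receive consecutive ids starting at start_label, the id reserved for invalid_label is never handed out, and invalid_key maps to invalid_label. `encode_sentences_sketch` is hypothetical. When a vocab is passed in, the sketch continues numbering after the largest existing id, which is our assumption rather than documented behavior.

```python
def encode_sentences_sketch(sentences, vocab=None, invalid_label=-1,
                            invalid_key='\n', start_label=0):
    """Map string tokens to integer ids, growing the vocabulary as needed."""
    if vocab is None:
        vocab = {invalid_key: invalid_label}
    else:
        vocab = dict(vocab)  # don't mutate the caller's dictionary
    # Next free id: start_label, or one past the largest id already in use.
    next_id = max([start_label] + [v + 1 for v in vocab.values()])
    encoded = []
    for sent in sentences:
        ids = []
        for token in sent:
            if token not in vocab:
                if next_id == invalid_label:
                    next_id += 1  # skip the id reserved for invalid tokens
                vocab[token] = next_id
                next_id += 1
            ids.append(vocab[token])
        encoded.append(ids)
    return encoded, vocab

result, vocab = encode_sentences_sketch([['the', 'cat'], ['the', 'dog', '\n']])
print(result)  # [[0, 1], [0, 2, -1]]
```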

rnn.save_rnn_checkpoint(cells, prefix, epoch, symbol, arg_params, aux_params)

Save checkpoint for model using RNN cells. Unpacks weight before saving.

Parameters:
  • cells (RNNCell or list of RNNCells) – The RNN cells used by this symbol.
  • prefix (str) – Prefix of model name.
  • epoch (int) – The epoch number of the model.
  • symbol (Symbol) – The input symbol
  • arg_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s weights.
  • aux_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s auxiliary states.

Notes

  • prefix-symbol.json will be saved for symbol.
  • prefix-epoch.params will be saved for parameters.
rnn.load_rnn_checkpoint(cells, prefix, epoch)

Load model checkpoint from file. Pack weights after loading.

Parameters:
  • cells (RNNCell or list of RNNCells) – The RNN cells used by this symbol.
  • prefix (str) – Prefix of model name.
  • epoch (int) – Epoch number of model we would like to load.
Returns:

  • symbol (Symbol) – The symbol configuration of computation network.
  • arg_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s weights.
  • aux_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s auxiliary states.

Notes

  • symbol will be loaded from prefix-symbol.json.
  • parameters will be loaded from prefix-epoch.params.
rnn.do_rnn_checkpoint(cells, prefix, period=1)

Make a callback to checkpoint Module to prefix every epoch. Unpacks weights used by cells before saving.

Parameters:
  • cells (RNNCell or list of RNNCells) – The RNN cells used by this symbol.
  • prefix (str) – The file prefix to checkpoint to
  • period (int) – How many epochs to wait before checkpointing. Default is 1.
Returns:

callback – The callback function that can be passed as epoch_end_callback to Module.fit.

Return type:

function
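The period logic of do_rnn_checkpoint can be sketched as follows. `make_checkpoint_callback` and its `save` argument are hypothetical stand-ins; in real code the saver would be save_rnn_checkpoint with the cells bound. The sketch makes three assumptions about the real callback. First, it receives a 0-based epoch index. Second, it labels the checkpoint with that index plus one. Third, the demo's file-name pattern prefix-0002.params, with a zero-padded four-digit epoch, matches the real format.

```python
def make_checkpoint_callback(save, prefix, period=1):
    """Return an epoch-end callback that calls save(...) every `period` epochs."""
    period = int(max(1, period))

    def _callback(iter_no, sym=None, arg=None, aux=None):
        # iter_no is 0-based, so the checkpoint is labelled iter_no + 1.
        if (iter_no + 1) % period == 0:
            save(prefix, iter_no + 1, sym, arg, aux)
    return _callback

saved = []
cb = make_checkpoint_callback(
    lambda prefix, epoch, *rest: saved.append(f'{prefix}-{epoch:04d}.params'),
    'lm', period=2)
for epoch in range(5):
    cb(epoch)
print(saved)  # ['lm-0002.params', 'lm-0004.params']
```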