Bucketing in MXNet¶
When we train recurrent neural networks (RNNs), we unroll the network in time. For a single example of length T, we unroll the network for T steps. In the unrolled view, the network is acyclic, so we can think of it as a feedforward neural network whose weights are tied across time steps. The unrolled view allows us to train the network with the standard backpropagation algorithm; in this setting, the algorithm is called backpropagation through time (BPTT).
Things get complicated when we work with datasets where the example sequences have varying lengths. In the recurrent view, each example can pass through the same architecture. But in the unrolled view, each example requires a different number of unrollings, and thus corresponds to a slightly different feedforward network. If we train on one example at a time, we can simply unroll the network for the required number of steps on each training iteration. Mini-batch training is harder, because all sequences in a mini-batch must share one unrolled network. A naive approach is to pad all the sequences so that they appear to have the length of the longest example. However, this can be wasteful because, for shorter sequences, most of the computation is done on padded data.
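To make the cost of naive padding concrete, the following minimal sketch computes the fraction of time steps that are pure padding when every sequence is padded to the longest one. The sequence lengths and the helper name `padding_waste` are made up for illustration.

```python
def padding_waste(lengths):
    """Fraction of time steps that are padding when every sequence
    is padded to the length of the longest one."""
    max_len = max(lengths)
    total_steps = max_len * len(lengths)   # steps actually computed
    real_steps = sum(lengths)              # steps that carry real data
    return (total_steps - real_steps) / total_steps

# Hypothetical sequence lengths, chosen only for illustration.
lengths = [3, 5, 4, 30]
waste = padding_waste(lengths)
print(f"{waste:.0%} of the computed time steps are padding")
```

With one long sequence among several short ones, most of the unrolled computation is spent on padding, which is the situation bucketing is meant to avoid.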
Borrowed from TensorFlow’s sequence training code, bucketing offers an effective way to form mini-batches out of varying-length sequences. Instead of unrolling the network to the maximum possible sequence length, we unroll multiple instances of different lengths (e.g., length 5, 10, 20, 30). During training, we use the most appropriate unrolled model for each mini-batch of data. Although the models with different numbers of unrollings have different feedforward architectures, they all share the same parameters. MXNet also reuses the internal memory buffers among all executors.
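As a rough illustration of the idea, the sketch below assigns each sequence to the smallest bucket that can hold it and compares the number of padded time steps with the naive pad-to-longest approach. The bucket sizes come from the text above; the sequence lengths and the helper `assign_bucket` are assumptions for illustration and are not part of MXNet.

```python
import bisect

def assign_bucket(length, buckets):
    """Return the smallest bucket that can hold a sequence of `length`,
    or None if the sequence is longer than every bucket.
    `buckets` must be sorted in ascending order."""
    idx = bisect.bisect_left(buckets, length)
    return buckets[idx] if idx < len(buckets) else None

buckets = [5, 10, 20, 30]        # bucket sizes from the text, sorted ascending
lengths = [3, 5, 4, 30, 12, 7]   # hypothetical sequence lengths

# Each sequence is padded only up to its own bucket size.
bucketed_steps = sum(assign_bucket(n, buckets) for n in lengths)
# Naive approach: every sequence is padded to the longest one.
naive_steps = max(lengths) * len(lengths)
print(bucketed_steps, naive_steps)
```

Sequences longer than the largest bucket get `None` here; a real data pipeline would have to drop, truncate, or otherwise handle them.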
For simple RNNs, you can use a for loop to explicitly go over the input sequence and perform backpropagation through time by maintaining the connection of the states and gradients through time. This approach works with variable-length sequences, but it is slow. For more complicated models (e.g., translation that uses a sequence-to-sequence model), explicit unrolling is the easiest way. In this example, we introduce the MXNet APIs that allow us to implement bucketing.
Variable-length Sequence Training for PTB¶
In this example, we use a simple architecture consisting of a word-embedding layer followed by two LSTM layers. In the original example, the model is unrolled explicitly in time for a fixed length of 32. To enable bucketing, MXNet needs to know how to construct a new unrolled symbolic architecture for a different sequence length. To achieve this, instead of constructing a model with a fixed `Symbol`, we use a callback function that generates a new `Symbol` for a given *bucket key*.
    model = mx.mod.BucketingModule(
        sym_gen=sym_gen,
        default_bucket_key=data_train.default_bucket_key,
        context=contexts)
`sym_gen` must be a function that takes one argument, `bucket_key`, and returns a `Symbol` for this bucket. In this example, we use the sequence length as the bucket key, but a bucket key can be anything. For example, in neural translation, because different combinations of input-output sequence lengths correspond to different unrollings, the bucket key could be a pair of lengths.
    def sym_gen(seq_len):
        return lstm_unroll(num_lstm_layer, seq_len, len(vocab),
                           num_hidden=num_hidden, num_embed=num_embed,
                           num_label=len(vocab))
The data iterator needs to report the `default_bucket_key`, which allows MXNet to do some parameter initialization before reading the data. Now the model is capable of training with different buckets by sharing the parameters and intermediate computation buffers between bucket executors.
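The following pure-Python toy illustrates the idea of building one unrolled model per bucket key on demand while every model reads the same parameter set. It is only a conceptual sketch: the class `ToyBucketingModel`, its method `model_for`, and `toy_sym_gen` are hypothetical names and do not reflect how MXNet implements this internally.

```python
class ToyBucketingModel:
    """Toy stand-in for the idea behind bucketing; not MXNet's implementation."""

    def __init__(self, sym_gen, default_bucket_key):
        self.sym_gen = sym_gen
        self.params = {"weight": 0.5}   # one parameter set shared by every bucket
        self.cache = {}                 # bucket_key -> generated model
        # The default key lets us build a first model before any data is read.
        self.model_for(default_bucket_key)

    def model_for(self, bucket_key):
        # Build the unrolled model only the first time a key is seen.
        if bucket_key not in self.cache:
            self.cache[bucket_key] = self.sym_gen(bucket_key)
        return self.cache[bucket_key]

def toy_sym_gen(seq_len):
    # Hypothetical "unrolled network": applies the shared weight seq_len times.
    def forward(params, x):
        for _ in range(seq_len):
            x = params["weight"] * x
        return x
    return forward

model = ToyBucketingModel(toy_sym_gen, default_bucket_key=3)
out3 = model.model_for(3)(model.params, 8.0)   # three unrolled steps
out2 = model.model_for(2)(model.params, 8.0)   # two unrolled steps, same weight
print(out3, out2, sorted(model.cache))
```

Because both unrolled functions read from the same `params` dictionary, an update made while training on one bucket is visible to all the others.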
To train, we still need to add some extra bits to our `DataIter`. Apart from reporting the `default_bucket_key` as mentioned previously, we also need to report the current `bucket_key` for each mini-batch. More specifically, the `DataBatch` object returned in each mini-batch by the `DataIter` should contain the following additional properties:
`bucket_key`: The bucket key that corresponds to this batch of data. In our example, it is the sequence length for this batch of data. If the executors corresponding to this bucket key have not yet been created, they will be constructed according to the symbolic model returned by `sym_gen` on this bucket key. The executors will be cached for future use. Note that the generated `Symbol`s could be arbitrary, but they should all have the same trainable parameters and auxiliary states.
`provide_data`: The same kind of information as reported by the `DataIter` object. Because each bucket now corresponds to a different architecture, different buckets could have different input data. Also, make sure that the `provide_data` information returned by the `DataIter` object is compatible with the architecture for the default bucket key.
`provide_label`: The same as `provide_data`, but for the labels.
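The properties listed above could be carried by a container along the following lines. This is a hedged sketch: `ToyBatch` and `make_batch` are hypothetical names, and the input and label names `"data"` and `"softmax_label"` are assumptions; the real `DataBatch` class in MXNet may differ in its details.

```python
from dataclasses import dataclass
from typing import List, Tuple

ShapeInfo = List[Tuple[str, Tuple[int, int]]]   # list of (name, shape) pairs

@dataclass
class ToyBatch:
    """Hypothetical container mirroring the extra fields described above."""
    data: List[List[int]]      # padded token ids, one row per sequence
    label: List[List[int]]     # padded targets, same shape as data
    bucket_key: int            # sequence length used for this batch
    provide_data: ShapeInfo    # input names and shapes for this bucket
    provide_label: ShapeInfo   # label names and shapes for this bucket

def make_batch(rows, labels, bucket_key):
    # Every row in the batch has been padded to `bucket_key` time steps.
    shape = (len(rows), bucket_key)
    return ToyBatch(rows, labels, bucket_key,
                    provide_data=[("data", shape)],
                    provide_label=[("softmax_label", shape)])

batch = make_batch([[4, 7, 0], [9, 2, 5]], [[7, 0, 0], [2, 5, 0]], bucket_key=3)
print(batch.bucket_key, batch.provide_data, batch.provide_label)
```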
DataIter is responsible for grouping the data into different buckets.
Assuming that randomization is enabled, at each iteration,
DataIter chooses a random bucket (according to a distribution balanced by the bucket sizes),
and then randomly chooses sequences from that bucket to form a mini-batch.
It also applies padding for sequences of different length within the mini-batch as necessary.
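The sampling scheme described above can be sketched in plain Python as follows. The generator `bucket_batches` is a hypothetical helper, not the iterator shipped with the MXNet example. It samples with replacement, silently drops sequences longer than the largest bucket, and assumes at least one sequence fits a bucket.

```python
import random

def bucket_batches(sequences, buckets, batch_size, pad=0, seed=0):
    """Yield (bucket_key, padded_batch) pairs by sampling buckets at random."""
    rng = random.Random(seed)
    grouped = {b: [] for b in buckets}
    for seq in sequences:
        # Place each sequence in the smallest bucket that fits it.
        for b in sorted(buckets):
            if len(seq) <= b:
                grouped[b].append(seq)
                break
    # Only non-empty buckets can be sampled; weight them by their size.
    keys = [b for b in buckets if grouped[b]]
    weights = [len(grouped[b]) for b in keys]
    while True:
        key = rng.choices(keys, weights=weights)[0]
        chosen = [rng.choice(grouped[key]) for _ in range(batch_size)]
        # Pad every chosen sequence up to the bucket length.
        yield key, [seq + [pad] * (key - len(seq)) for seq in chosen]

sequences = [[1, 2], [3, 4, 5], [6], [7, 8, 9, 10, 11, 12]]
batches = bucket_batches(sequences, buckets=[3, 6], batch_size=2)
key, batch = next(batches)
print(key, batch)
```

The returned `key` plays the role of the `bucket_key` that each mini-batch must report.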
For a full, working implementation of a `DataIter` that reads text sequences as described above, see example/rnn/lstm_ptb_bucketing.py.
In this example, you can use bucketing with a static configuration (e.g.,
buckets = [10, 20, 30, 40, 50, 60]), or let MXNet generate buckets automatically according to the characteristics of the dataset (
buckets = ). The latter approach is implemented by adding a bucket as long as the number of sequences assigned to that bucket exceeds some minimum count. For more information, see `default_gen_buckets()`.
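One way such automatic bucket generation could work is sketched below. This is an assumption-laden illustration, not the code of `default_gen_buckets()`: the function `gen_buckets`, its `min_count` parameter, and the choice to add a bucket once the pending count reaches (rather than strictly exceeds) the minimum are all made up here.

```python
from collections import Counter

def gen_buckets(lengths, min_count):
    """Walk the lengths in ascending order and open a bucket whenever
    enough sequences have accumulated since the previous bucket."""
    counts = Counter(lengths)
    buckets, pending = [], 0
    for length in sorted(counts):
        pending += counts[length]
        if pending >= min_count:
            buckets.append(length)
            pending = 0
    # Make sure the longest remaining sequences still have a bucket.
    if pending:
        buckets.append(max(counts))
    return buckets

lengths = [3, 3, 4, 5, 5, 5, 9, 12]   # hypothetical sequence lengths
buckets = gen_buckets(lengths, min_count=3)
print(buckets)
```

A larger `min_count` yields fewer, wider buckets (more padding, fewer unrolled models); a smaller one yields more buckets that fit the data more tightly.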
Beyond Sequence Training¶
In this example, we briefly explained how the bucketing API works.
However, the API is not limited to bucketing by sequence lengths.
The bucket key can be an arbitrary object, as long as the architectures returned by `sym_gen` for different keys are compatible with one another (that is, they have the same set of parameters).
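As a small illustration of a non-scalar bucket key, the sketch below uses a pair of lengths, as suggested earlier for neural translation, and checks that every generated architecture exposes the same parameter names. The generator `toy_sym_gen`, the dictionary it returns, and the parameter names are hypothetical stand-ins for real `Symbol` objects.

```python
def toy_sym_gen(bucket_key):
    """Hypothetical generator keyed on a (source_len, target_len) pair."""
    src_len, tgt_len = bucket_key
    return {
        # The unrolling depends on the key...
        "unrolled_steps": {"encoder": src_len, "decoder": tgt_len},
        # ...but the parameter names do not, so every bucket shares them.
        "param_names": ["embed_weight", "encoder_weight", "decoder_weight"],
    }

keys = [(5, 10), (10, 15), (20, 25)]
models = {key: toy_sym_gen(key) for key in keys}

# Compatibility check: all architectures must declare the same parameters.
reference = models[keys[0]]["param_names"]
compatible = all(m["param_names"] == reference for m in models.values())
print(compatible)
```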