Using Baidu Warp-CTC with MXNet
Baidu Warp-CTC is a fast implementation of CTC (Connectionist Temporal Classification) from Baidu that runs on both CPUs and GPUs. Used together with an LSTM, CTC solves label-alignment problems that arise in many areas, such as OCR and speech recognition.
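To make "label alignment" concrete: CTC lets the network emit one label per input frame (for OCR, one per image column) plus a special blank, and the final output is obtained by merging repeated labels and then dropping blanks. The sketch below is the standard best-path (greedy) collapsing rule in plain Python. It is not part of Warp-CTC, which computes the CTC loss and its gradients rather than decoding. The function name `ctc_greedy_collapse` is hypothetical, and the blank is assumed to be label 0, matching the convention that 0 is reserved for the blank label.

```python
def ctc_greedy_collapse(frame_labels, blank=0):
    """Collapse a per-frame best-path label sequence into a CTC output.

    Repeated labels are merged first, then blanks are removed.
    """
    output = []
    prev = None
    for label in frame_labels:
        # Keep a label only if it differs from the previous frame and is not blank.
        if label != prev and label != blank:
            output.append(label)
        prev = label
    return output


# Eight frames decode to three labels; the blank between the two 3s
# prevents them from being merged into a single 3.
print(ctc_greedy_collapse([0, 3, 3, 0, 3, 5, 5, 0]))  # [3, 3, 5]
```

The blank is what allows CTC to output the same label twice in a row, which matters for captchas such as "1133".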
You can get the source code for the example on GitHub.
Install Baidu Warp-CTC
```bash
cd ~/
git clone https://github.com/baidu-research/warp-ctc
cd warp-ctc
mkdir build
cd build
cmake ..
make
sudo make install
```
Enable Warp-CTC in MXNet
Uncomment the following two lines in make/config.mk:

```makefile
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/warpctc.mk
```

Then rebuild MXNet:

```bash
make clean && make -j4
```
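If you prefer to script the configuration step, the following sketch uncomments the two Warp-CTC lines in the text of make/config.mk. It is a hypothetical helper (`enable_warpctc` is not part of MXNet), and it assumes the stock file contains those two lines prefixed with `#`; check your copy of config.mk before relying on it.

```python
# The two lines that enable the Warp-CTC plugin, as given in the instructions above.
WARPCTC_LINES = (
    "WARPCTC_PATH = $(HOME)/warp-ctc",
    "MXNET_PLUGINS += plugin/warpctc/warpctc.mk",
)


def enable_warpctc(config_text):
    """Return config_text with the Warp-CTC lines uncommented (hypothetical helper)."""
    out = []
    for line in config_text.splitlines():
        # What the line would look like with leading whitespace and '#' removed.
        candidate = line.lstrip().lstrip("#").strip()
        # Only the two Warp-CTC lines are uncommented; everything else is kept verbatim.
        out.append(candidate if candidate in WARPCTC_LINES else line)
    return "\n".join(out) + "\n"


sample = (
    "# WARPCTC_PATH = $(HOME)/warp-ctc\n"
    "# MXNET_PLUGINS += plugin/warpctc/warpctc.mk\n"
    "USE_CUDA = 0\n"
)
print(enable_warpctc(sample))
```

In practice you would read make/config.mk from the MXNet source root, pass its contents through `enable_warpctc`, write the result back, and then rebuild with `make clean && make -j4`.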
There are two examples: a toy example that validates the CTC integration, and an OCR example that combines LSTM and CTC. To run the OCR example, type the following commands:

```bash
cd example/warpctc
python lstm_ocr.py
```
The OCR example is constructed as follows:
- It uses a Python captcha library to generate an 80x30-pixel image of a 4-digit captcha.
- The image is fed to the LSTM as a sequence of 80 inputs; each input is one column of the image (a 30-dimensional vector).
- The output layer uses CTC loss.
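The column-per-time-step idea in the list above can be sketched with NumPy. This assumes a grayscale image stored as a `(height, width) = (30, 80)` array; the helper name `image_to_sequence` is hypothetical and not taken from the example's source. Because the network code in this section splits its input with `SliceChannel(..., squeeze_axis=1)` (whose default split axis is 1), a batch is assumed to have the layout `(batch_size, 80, 30)`.

```python
import numpy as np

# Assumed image layout: 30 pixel rows by 80 pixel columns.
HEIGHT, WIDTH = 30, 80


def image_to_sequence(image):
    """Turn a (30, 80) image into 80 time steps of 30-dim column vectors."""
    if image.shape != (HEIGHT, WIDTH):
        raise ValueError("expected shape %r, got %r" % ((HEIGHT, WIDTH), image.shape))
    # After transposing, row t of the result is column t of the image.
    return image.T  # shape (80, 30)


# Stack four random "captchas" into one batch laid out as (batch, time, features).
batch = np.stack([image_to_sequence(np.random.rand(HEIGHT, WIDTH)) for _ in range(4)])
print(batch.shape)  # (4, 80, 30)
```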
The following code shows how the network is constructed:
```python
import mxnet as mx

# LSTMParam, LSTMState and lstm() are helpers defined elsewhere in the
# example (not shown here); lstm() builds one step of one LSTM layer.


def lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')

    # every column of the image is an input, so there are seq_len inputs
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    # stack all time steps: shape (seq_len * batch_size, num_hidden)
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    # 11 output classes per time step
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)

    # here we do NOT need to transpose the label as other LSTM examples do
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    # the label must be of integer type, so cast it
    label = mx.sym.Cast(data=label, dtype='int32')
    sm = mx.sym.WarpCTC(data=pred, label=label, label_length=num_label,
                        input_length=seq_len)
    return sm
```
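The tensor layout produced by `lstm_unroll` is easy to get wrong when post-processing predictions, so here is a NumPy sketch that mimics the shapes only (it does not use MXNet). `batch_size=2` and `num_hidden=100` are arbitrary choices, the fully connected layer is a random stand-in without bias, and reading the 11 classes as 10 digits plus the blank is an assumption based on the 4-digit captcha and the reserved blank label 0.

```python
import numpy as np

seq_len, batch_size, num_hidden, num_classes = 80, 2, 100, 11

# Stand-ins for hidden_all: one (batch_size, num_hidden) array per time step.
hidden_all = [np.random.rand(batch_size, num_hidden) for _ in range(seq_len)]

# Like mx.sym.Concat(*hidden_all, dim=0): rows are time-major,
# so row t * batch_size + b holds time step t of sample b.
hidden_concat = np.concatenate(hidden_all, axis=0)

# Random stand-in for FullyConnected(num_hidden=11), bias omitted.
weight = np.random.rand(num_hidden, num_classes)
pred = hidden_concat @ weight  # (seq_len * batch_size, 11)

# Undo the time-major flattening to get per-sample, per-frame scores,
# then take the best class at every frame (input to CTC greedy decoding).
per_frame = pred.reshape(seq_len, batch_size, num_classes)
best_path = per_frame.argmax(axis=2)  # (seq_len, batch_size)

# Labels stay sample-major: a (batch_size, num_label) array flattened row by row.
labels = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
flat_labels = labels.reshape(-1)  # [1 2 3 4 5 6 7 8]
```

This is why the label needs no transpose: WarpCTC receives the activations time-major and the labels as one flat, sample-ordered vector.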
Supporting Multiple Label Lengths
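To support labels of different lengths, give every label the same fixed length b. For samples whose label is shorter than b, append 0s to the label data until it has length b.
Because 0 is reserved for the blank label, it never occurs in a real label, so the padding can be told apart from the actual label values.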
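A minimal sketch of the padding rule above. `pad_label` follows the text directly; `encode_digits` is a hypothetical helper reflecting an assumption not stated on this page: since 0 is the blank, digit d is stored as class d + 1.

```python
def encode_digits(digits):
    """Map a digit string to class ids, assuming digit d -> class d + 1 (0 is blank)."""
    return [int(d) + 1 for d in digits]


def pad_label(label, b, blank=0):
    """Right-pad a label with the blank id 0 so that it has length b."""
    if len(label) > b:
        raise ValueError("label of length %d is longer than b=%d" % (len(label), b))
    return list(label) + [blank] * (b - len(label))


print(pad_label(encode_digits("2019"), 6))  # [3, 1, 2, 10, 0, 0]
print(pad_label(encode_digits("7"), 4))     # [8, 0, 0, 0]
```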