Module API

The module API provides intermediate-level and high-level interfaces for performing computation with neural networks in MXNet. A module is an instance of a subclass of BaseModule. The most widely used module class is Module, which wraps a Symbol and one or more Executors. For a full list of functions, see BaseModule. Subclasses might provide additional interface functions. This topic provides examples of common use cases. All of the module APIs are in the module package (ml.dmlc.mxnet.module).

Preparing a Module for Computation

To construct a module, refer to the constructors for the module class. For example, the Module class accepts a Symbol as input:

    import ml.dmlc.mxnet._
    import ml.dmlc.mxnet.module.{FitParams, Module}

    // construct a simple MLP
    val data = Symbol.Variable("data")
    val fc1 = Symbol.FullyConnected(name = "fc1")(data)(Map("num_hidden" -> 128))
    val act1 = Symbol.Activation(name = "relu1")(fc1)(Map("act_type" -> "relu"))
    val fc2 = Symbol.FullyConnected(name = "fc2")(act1)(Map("num_hidden" -> 64))
    val act2 = Symbol.Activation(name = "relu2")(fc2)(Map("act_type" -> "relu"))
    val fc3 = Symbol.FullyConnected(name = "fc3")(act2)(Map("num_hidden" -> 10))
    val out = Symbol.SoftmaxOutput(name = "softmax")(fc3)()

    // construct the module
    val mod = new Module(out)

By default, the context is the CPU. If you need data parallelism, you can specify a GPU context or an array of GPU contexts.
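As a minimal sketch of choosing devices, the snippet below wraps a tiny network in a Module that uses two GPU contexts. The tiny network, the device ids 0 and 1, and the `contexts` parameter name are assumptions made for illustration; check the Module constructor in your MXNet version for the exact signature.

```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.Module

object ContextExample {
  // A deliberately tiny network, used only to have a Symbol to wrap.
  def tinyNet(): Symbol = {
    val data = Symbol.Variable("data")
    val fc = Symbol.FullyConnected(name = "fc")(data)(Map("num_hidden" -> 10))
    Symbol.SoftmaxOutput(name = "softmax")(fc)()
  }

  // Assumed device ids 0 and 1; adjust to the GPUs available on your machine.
  val gpuContexts: Array[Context] = Array(Context.gpu(0), Context.gpu(1))

  // `contexts` is assumed to be the name of the constructor parameter.
  def gpuModule(): Module = new Module(tinyNet(), contexts = gpuContexts)
}
```

With more than one context, each batch is expected to be split across the listed devices, so the batch size should be divisible by the number of contexts.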

Before you can compute with a module, you need to call bind() to allocate the device memory, and then initParams() or setParams() to initialize the parameters. If you simply want to fit a module, you don't need to call bind() and initParams() explicitly, because the fit() function calls them automatically when they are needed.

    mod.bind(dataShapes = train_dataiter.provideData, labelShapes = Some(train_dataiter.provideLabel))

Now you can compute with the module using functions like forward(), backward(), etc.
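The intermediate-level calls can be combined into a hand-written training loop. The following is a sketch under stated assumptions, not a reference implementation: the helper name `trainManually`, the choice of SGD with a learning rate of 0.1, and the Accuracy metric are illustrative, and the `initOptimizer()`, `updateMetric()`, and `update()` calls are assumed to have the signatures used here.

```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.Module
import ml.dmlc.mxnet.optimizer.SGD

object ManualTrainingExample {
  // Trains `mod` on `trainIter` for `numEpoch` epochs, one batch at a time.
  // `trainManually` is a hypothetical helper; both arguments come from the caller.
  def trainManually(mod: Module, trainIter: DataIter, numEpoch: Int): Unit = {
    // Allocate memory, then initialize the parameters and the optimizer.
    mod.bind(dataShapes = trainIter.provideData,
      labelShapes = Some(trainIter.provideLabel))
    mod.initParams()
    mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f))

    val metric = new Accuracy()
    for (epoch <- 0 until numEpoch) {
      while (trainIter.hasNext) {
        val batch = trainIter.next()
        mod.forward(batch)                     // compute the outputs
        mod.updateMetric(metric, batch.label)  // accumulate training accuracy
        mod.backward()                         // compute the gradients
        mod.update()                           // apply one optimizer step
      }
      println(s"epoch $epoch: ${metric.get}")
      // Start the next epoch with a fresh metric and a rewound iterator.
      metric.reset()
      trainIter.reset()
    }
  }
}
```

This is roughly what fit() does for you; writing the loop by hand is mainly useful when you need custom logic between the forward and backward passes.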

Training, Predicting, and Evaluating

Modules provide high-level APIs for training, predicting, and evaluating. To fit a module, call the fit() function with some DataIters:

    import ml.dmlc.mxnet.optimizer.SGD
    // "out" is the SoftmaxOutput symbol constructed earlier
    val mod = new Module(out)
    mod.fit(train_dataiter, evalData = scala.Option(eval_dataiter),
      numEpoch = n_epoch, fitParams = new FitParams()
        .setOptimizer(new SGD(learningRate = 0.1f, momentum = 0.9f, wd = 0.0001f)))

The interface is very similar to the old FeedForward class. You can pass in batch-end callbacks using setBatchEndCallback and epoch-end callbacks using setEpochEndCallback. You can also set parameters using methods like setOptimizer and setEvalMetric. To learn more about FitParams, see the API page.

To predict with a module, call predict() with a DataIter, for example mod.predict(val_dataiter).

The module collects and returns all of the prediction results. For more details about the format of the return values, see the documentation for the predict() function.
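To make the return value more concrete, here is a small sketch that runs predict() and reports the shape of each collected output. It assumes that predict() returns one merged NDArray per network output and that the module is already bound with initialized parameters; `predictAll` is a hypothetical helper name.

```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.Module

object PredictExample {
  // Runs prediction over the whole iterator and reports each output's shape.
  // Assumes `mod` has been bound and its parameters initialized or loaded.
  def predictAll(mod: Module, valIter: DataIter): IndexedSeq[NDArray] = {
    val outputs: IndexedSeq[NDArray] = mod.predict(valIter)
    outputs.zipWithIndex.foreach { case (out, i) =>
      println(s"output $i has shape ${out.shape}")
    }
    outputs
  }
}
```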

When prediction results might be too large to fit in memory, use the predictEveryBatch API:

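    val preds = mod.predictEveryBatch(val_dataiter)
    // rewind the iterator, which predictEveryBatch has consumed
    val_dataiter.reset()
    var i = 0
    while (val_dataiter.hasNext) {
      val batch = val_dataiter.next()
      // preds(i) holds the outputs for batch i; pick the most likely class per sample
      val predLabel: Array[Int] = NDArray.argmax_channel(preds(i)(0)).toArray.map(_.toInt)
      val label = batch.label(0)
      // do something...
      i += 1
    }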

If you need to evaluate on a test set and don’t need the prediction output, call the score() function with a DataIter and an EvalMetric:

    mod.score(val_dataiter, metric)

This runs predictions on each batch in the provided DataIter and computes the evaluation score using the provided EvalMetric. The evaluation results are stored in metric so that you can query them later.
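A brief sketch of the scoring workflow follows. The Accuracy metric and the helper name `accuracyOn` are assumptions chosen for illustration; any EvalMetric should work the same way, and the module is assumed to be bound with initialized parameters.

```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.Module

object ScoreExample {
  // Evaluates `mod` on `valIter` and returns the metric that holds the result.
  def accuracyOn(mod: Module, valIter: DataIter): EvalMetric = {
    val metric = new Accuracy()
    mod.score(valIter, metric)
    // The accumulated result stays in the metric and can be read at any time.
    println(s"validation result: ${metric.get}")
    metric
  }
}
```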

Saving and Loading Module Parameters

To save the module parameters after each training epoch, call saveCheckpoint at the end of the epoch:

    val modelPrefix: String = "mymodel"

    for (epoch <- 0 until 5) {
      // forward and backward pass
      // do something...
      mod.saveCheckpoint(modelPrefix, epoch, saveOptStates = true)
    }


To load the saved module parameters, call the loadCheckpoint function:

    val mod = Module.loadCheckpoint(modelPrefix, loadModelEpoch, loadOptimizerStates = true)

To initialize parameters, first bind the symbols to construct executors with the bind method. Then initialize the parameters and auxiliary states by calling the initParams() method.

    mod.bind(dataShapes = train_dataiter.provideData, labelShapes = Some(train_dataiter.provideLabel))
    mod.initParams()

To get the current parameters, use the getParams method.

    val (argParams, auxParams) = mod.getParams

To assign parameter and auxiliary state values, use the setParams method.

    mod.setParams(argParams, auxParams)
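The two calls can be combined, for instance to inspect the parameters of one module and copy them into another. This is a sketch only: `copyParams` is a hypothetical helper, and both modules are assumed to be bound already and to wrap networks with matching parameter names and shapes.

```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.Module

object ParamsExample {
  // Prints the name and shape of every parameter of `source`,
  // then assigns the same values to `target`.
  def copyParams(source: Module, target: Module): Unit = {
    val (argParams, auxParams) = source.getParams
    argParams.foreach { case (name, arr) => println(s"$name: ${arr.shape}") }
    target.setParams(argParams, auxParams)
  }
}
```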

To resume training from a saved checkpoint, don't call setParams(). Instead, call fit() directly and pass the loaded parameters through the fitParams argument, for example fitParams = new FitParams().setArgParams(argParams), so that fit() starts from those parameters instead of initializing randomly.

To resume from a saved epoch, also call setBeginEpoch() on the same FitParams object to pass beginEpoch to fit().
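Putting these pieces together, resuming might look like the sketch below. The helper name `resumeFit` is hypothetical, and the `setAuxParams()` call is an assumption added so that the auxiliary states are restored alongside the parameters. `numEpoch` is assumed to be the epoch at which training stops, not the number of additional epochs.

```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.{FitParams, Module}

object ResumeExample {
  // Continues training `mod` from previously saved parameters.
  // `argParams` and `auxParams` are the values obtained earlier via getParams.
  def resumeFit(mod: Module, trainIter: DataIter,
                argParams: Map[String, NDArray], auxParams: Map[String, NDArray],
                beginEpoch: Int, numEpoch: Int): Unit = {
    val params = new FitParams()
      .setArgParams(argParams)    // start from the saved weights
      .setAuxParams(auxParams)    // assumed setter for the auxiliary states
      .setBeginEpoch(beginEpoch)  // continue epoch numbering from the checkpoint
    mod.fit(trainIter, numEpoch = numEpoch, fitParams = params)
  }
}
```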

Next Steps

  • See Model API for an alternative simple high-level interface for training neural networks.
  • See Symbolic API for operations on NDArrays that assemble neural networks from layers.
  • See IO Data Loading API for parsing and loading data.
  • See NDArray API for vector/matrix/tensor operations.
  • See KVStore API for multi-GPU and multi-host distributed training.