MXNet Scala Model API

The model API provides a simplified way to train neural networks using common best practices. It is a thin wrapper built on top of the ndarray and symbolic modules, designed to make neural network training easy.


Train the Model

To train a model, perform two steps: define the network as a symbol, then pass that symbol to FeedForward.newBuilder, set the training options, and call build() to create the model and fit it on the training data. The following example creates a two-layer neural network.

    // configure a two-layer neural network
    val data = Symbol.Variable("data")
    val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128))
    val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
    val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64))
    val softmax = Symbol.SoftmaxOutput(name = "sm")()(Map("data" -> fc2))

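    // Construct the FeedForward model and fit it on the training data.
    // trainDataIter and valDataIter are assumed to be existing DataIter instances.
    val model = FeedForward.newBuilder(softmax)
      .setContext(Context.cpu())
      .setNumEpoch(10)
      .setOptimizer(new SGD(learningRate = 0.01f, momentum = 0.9f, wd = 0.0001f))
      .setTrainData(trainDataIter)
      .setEvalData(valDataIter)
      .build()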
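To make concrete what the symbolic graph above computes, here is a plain-Scala sketch of the same forward pass (fully connected, then ReLU, then fully connected, then softmax) for a single input vector. It does not use MXNet at all: the object `TwoLayerSketch` and its functions are hypothetical illustrations, and the tiny layer sizes used with it are chosen for readability rather than the 128 and 64 hidden units in the example.

```scala
// Plain-Scala sketch (no MXNet) of the forward pass computed by
// FullyConnected -> Activation(relu) -> FullyConnected -> SoftmaxOutput
// for one input vector. All names here are hypothetical.
object TwoLayerSketch {
  type Vec = Array[Double]
  type Mat = Array[Array[Double]] // one row of weights per output unit

  // FullyConnected: y(j) = sum_i w(j)(i) * x(i) + b(j)
  def dense(x: Vec, w: Mat, b: Vec): Vec =
    w.indices.map { j =>
      w(j).indices.map(i => w(j)(i) * x(i)).sum + b(j)
    }.toArray

  // Activation with act_type = "relu": max(0, v) elementwise
  def relu(x: Vec): Vec = x.map(v => math.max(0.0, v))

  // Softmax: exp(v - max) / sum(exp(v - max)); subtracting the max avoids overflow
  def softmax(x: Vec): Vec = {
    val m = x.max
    val e = x.map(v => math.exp(v - m))
    val s = e.sum
    e.map(_ / s)
  }

  // fc1 -> relu -> fc2 -> softmax, mirroring the symbol graph above
  def forward(x: Vec, w1: Mat, b1: Vec, w2: Mat, b2: Vec): Vec =
    softmax(dense(relu(dense(x, w1, b1)), w2, b2))
}
```

The output is a probability vector over the classes; during training, SoftmaxOutput compares it with the label to produce the gradient.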

You can also create a model in the scikit-learn style: construct it first, then call fit.

    // create a model using the sklearn-style two-step way
    val model = new FeedForward(softmax,
                                numEpoch = numEpochs,
                                argParams = argParams,
                                auxParams = auxParams,
                                beginEpoch = beginEpoch,
                                epochSize = epochSize)
    // train is a DataIter over the training set
    model.fit(trainData = train)

For more information, see the API Reference.

Save the Model

When training finishes, save your work. The model API provides save and load functions; use the load function to restore a model checkpoint from a file.

    // checkpoint the model data into files:
    // saves the model to modelPrefix-symbol.json and modelPrefix-0100.params
    // (symbol, argParams and auxParams are the trained network and its parameters)
    val modelPrefix: String = "checkpt"
    val numEpoch = 100
    Model.saveCheckpoint(modelPrefix, numEpoch, symbol, argParams, auxParams)

    // load the model back
    val modelLoaded = FeedForward.load(modelPrefix, numEpoch)

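The advantage of these save and load functions is that the checkpoint format is language agnostic, so a checkpoint written from Scala can be loaded by the other MXNet language bindings. You should also be able to save to and load from cloud storage, such as Amazon S3 and HDFS, directly.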
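The checkpoint file names follow a simple convention: the prefix plus `-symbol.json` for the network definition, and the prefix plus the epoch number zero-padded to four digits plus `.params` for the parameters. A minimal sketch of that naming scheme, assuming the four-digit padding implied by `modelPrefix-0100.params`; `CheckpointNames` is a hypothetical helper, not part of MXNet.

```scala
// Hypothetical helper (not an MXNet API) reproducing the checkpoint file
// naming described above: <prefix>-symbol.json and <prefix>-<epoch as %04d>.params.
object CheckpointNames {
  // The network definition is stored once per prefix.
  def symbolFile(prefix: String): String = s"$prefix-symbol.json"

  // Parameters are stored per epoch; the epoch is zero-padded to four digits, e.g. 100 -> "0100".
  def paramsFile(prefix: String, epoch: Int): String = f"$prefix%s-$epoch%04d.params"
}
```

For example, `CheckpointNames.paramsFile("checkpt", 100)` yields `checkpt-0100.params`, which is the file `FeedForward.load("checkpt", 100)` reads parameters from.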

Periodic Checkpointing

We recommend checkpointing your model after each epoch. To do this, implement an EpochEndCallback whose invoke method calls Model.saveCheckpoint(<parameters>), and pass it to the training function so that it runs at the end of every epoch.

    // modelPrefix-symbol.json will be saved for the symbol.
    // modelPrefix-<epoch>.params will be saved for the parameters.
    // For more information, see the API documentation.
    val modelPrefix: String = "checkpt"
    val checkpoint: EpochEndCallback =
      if (modelPrefix == null) null
      else new EpochEndCallback {
        override def invoke(epoch: Int, symbol: Symbol,
                            argParams: Map[String, NDArray],
                            auxStates: Map[String, NDArray]): Unit = {
          // epoch is zero-based, so the first checkpoint is saved as epoch 1
          Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxStates)
        }
      }

    // Load a model checkpoint from file. Returns (symbol, argParams, auxParams).
    val numEpoch = 100 // epoch number of the checkpoint to load
    val (_, argParams, _) = Model.loadCheckpoint(modelPrefix, numEpoch)

You can load the model checkpoint later using Model.loadCheckpoint(modelPrefix, numEpoch), where numEpoch is the epoch number of the saved checkpoint.
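The callback mechanism can be illustrated without MXNet: the training loop calls a hook at the end of every epoch with the zero-based epoch index, and the hook decides what to save. In this sketch, `EpochHook` is a hypothetical, simplified stand-in for EpochEndCallback (the real trait also receives the symbol and the parameter maps), and `CallbackSketch.train` is a hypothetical loop, not MXNet's fit.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for EpochEndCallback: the real trait also receives
// the Symbol and the parameter maps; they are omitted to keep this self-contained.
trait EpochHook {
  def invoke(epoch: Int): Unit
}

object CallbackSketch {
  // Simplified training loop: after each pass over the data, call the hook
  // with the zero-based epoch index (a null hook means "no checkpointing").
  def train(numEpoch: Int, hook: EpochHook): Unit =
    for (epoch <- 0 until numEpoch) {
      // ... one epoch of training would run here ...
      if (hook != null) hook.invoke(epoch)
    }

  // Records the parameter file a checkpointing hook would write each epoch,
  // saving under epoch + 1 so the first checkpoint is numbered 0001.
  def checkpointNames(prefix: String, numEpoch: Int): Seq[String] = {
    val saved = ArrayBuffer.empty[String]
    val hook = new EpochHook {
      override def invoke(epoch: Int): Unit =
        saved += f"$prefix%s-${epoch + 1}%04d.params"
    }
    train(numEpoch, hook)
    saved.toSeq
  }
}
```

Running three epochs with prefix `checkpt` records `checkpt-0001.params` through `checkpt-0003.params`, which is why loading the final checkpoint uses the number of completed epochs.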

Use Multiple Devices

Set ctx to an array of the devices that you want to train on. You can build this array however you like, for example from a list of GPU IDs.

    val devices = Array(Context.gpu(0), Context.gpu(1))

    val model = new FeedForward(ctx = devices,
                                symbol = network,
                                numEpoch = numEpochs,
                                optimizer = optimizer,
                                epochSize = epochSize)

Training occurs in parallel on the GPUs that you specify.
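Multi-device training here is data parallel: each batch is divided among the devices, each device computes gradients on its share, and the gradients are combined before the parameters are updated. The sketch below shows only the splitting step, cutting a batch into contiguous slices whose sizes differ by at most one. `BatchSplitSketch.splitBatch` is a hypothetical helper and is not claimed to match MXNet's exact slicing or rounding.

```scala
// Hypothetical sketch (not MXNet's implementation) of data-parallel batch
// splitting: a batch of batchSize samples is cut into contiguous slices,
// one per device, whose sizes differ by at most one.
object BatchSplitSketch {
  // Returns (start, end) index pairs, with end exclusive.
  def splitBatch(batchSize: Int, numDevices: Int): Seq[(Int, Int)] = {
    require(numDevices > 0 && batchSize >= numDevices)
    val base = batchSize / numDevices
    val extra = batchSize % numDevices // the first `extra` devices get one more sample
    var start = 0
    (0 until numDevices).map { d =>
      val size = base + (if (d < extra) 1 else 0)
      val slice = (start, start + size)
      start += size
      slice
    }
  }
}
```

With two GPUs and a batch of 128, each device would process 64 samples per step: `splitBatch(128, 2)` gives `(0, 64)` and `(64, 128)`.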

Next Steps