Char RNN Example

This tutorial shows how to use an LSTM model to build a char-level language model, and generate text from it. For demonstration purposes, we use a Shakespearean text. You can find the data on GitHub.

Load the Data

Load in the data and preprocess it:

   require(mxnet)
   ## Loading required package: mxnet
   ## Loading required package: methods

Set the basic network parameters:

   batch.size = 32
   seq.len = 32
   num.hidden = 16
   num.embed = 16
   num.lstm.layer = 1
   num.round = 1
   learning.rate= 0.1
   wd=0.00001
   clip_gradient=1
   update.period = 1

Download the data:

   download.data <- function(data_dir) {
       dir.create(data_dir, showWarnings = FALSE)
       if (!file.exists(paste0(data_dir,'input.txt'))) {
           download.file(url='https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt',
                         destfile=paste0(data_dir,'input.txt'), method='wget')
       }
   }

Make a dictionary from the text:

   make.dict <- function(text, max.vocab=10000) {
       text <- strsplit(text, '')
       dic <- list()
       idx <- 1
       for (c in text[[1]]) {
           if (!(c %in% names(dic))) {
               dic[[c]] <- idx
               idx <- idx + 1
           }
       }
       if (length(dic) == max.vocab - 1)
           dic[["UNKNOWN"]] <- idx
       cat(paste0("Total unique char: ", length(dic), "\n"))
       return (dic)
    }

Transfer the text into a data feature:

   make.data <- function(file.path, seq.len=32, max.vocab=10000, dic=NULL)      {
       fi <- file(file.path, "r")
       text <- paste(readLines(fi), collapse="\n")
       close(fi)

       if (is.null(dic))
           dic <- make.dict(text, max.vocab)
       lookup.table <- list()
       for (c in names(dic)) {
           idx <- dic[[c]]
           lookup.table[[idx]] <- c
        }

       char.lst <- strsplit(text, '')[[1]]
       num.seq <- as.integer(length(char.lst) / seq.len)
       char.lst <- char.lst[1:(num.seq * seq.len)]
       data <- array(0, dim=c(seq.len, num.seq))
       idx <- 1
       for (i in 1:num.seq) {
            for (j in 1:seq.len) {
                if (char.lst[idx] %in% names(dic))
                   data[j, i] <- dic[[ char.lst[idx] ]]-1
               else {
                   data[j, i] <- dic[["UNKNOWN"]]-1
               }
               idx <- idx + 1
           }
       }
        return (list(data=data, dic=dic, lookup.table=lookup.table))
   }

Move the tail text:

   drop.tail <- function(X, batch.size) {
       shape <- dim(X)
       nstep <- as.integer(shape[2] / batch.size)
       return (X[, 1:(nstep * batch.size)])
   }

Get the label of X:

   get.label <- function(X) {
       label <- array(0, dim=dim(X))
       d <- dim(X)[1]
       w <- dim(X)[2]
       for (i in 0:(w-1)) {
           for (j in 1:d) {
               label[i*d+j] <- X[(i*d+j)%%(w*d)+1]
           }
       }
       return (label)
   }

Get the training data and evaluation data:

   download.data("./data/")
   ret <- make.data("./data/input.txt", seq.len=seq.len)
   ## Total unique char: 65
   X <- ret$data
   dic <- ret$dic
   lookup.table <- ret$lookup.table

   vocab <- length(dic)

   shape <- dim(X)
   train.val.fraction <- 0.9
   size <- shape[2]

   X.train.data <- X[, 1:as.integer(size * train.val.fraction)]
   X.val.data <- X[, -(1:as.integer(size * train.val.fraction))]
   X.train.data <- drop.tail(X.train.data, batch.size)
   X.val.data <- drop.tail(X.val.data, batch.size)

   X.train.label <- get.label(X.train.data)
   X.val.label <- get.label(X.val.data)

   X.train <- list(data=X.train.data, label=X.train.label)
   X.val <- list(data=X.val.data, label=X.val.label)

Train the Model

In mxnet, we have a function called mx.lstm so that users can build a general LSTM model:

   model <- mx.lstm(X.train, X.val,
                    ctx=mx.cpu(),
                    num.round=num.round,
                    update.period=update.period,
                    num.lstm.layer=num.lstm.layer,
                    seq.len=seq.len,
                    num.hidden=num.hidden,
                    num.embed=num.embed,
                    num.label=vocab,
                    batch.size=batch.size,
                    input.size=vocab,
                    initializer=mx.init.uniform(0.1),
                    learning.rate=learning.rate,
                    wd=wd,
                    clip_gradient=clip_gradient)
   ## Epoch [31] Train: NLL=3.53787130224343, Perp=34.3936275728271
   ## Epoch [62] Train: NLL=3.43087958036949, Perp=30.903813186055
   ## Epoch [93] Train: NLL=3.39771238228587, Perp=29.8956319855751
   ## Epoch [124] Train: NLL=3.37581711716687, Perp=29.2481732041015
   ## Epoch [155] Train: NLL=3.34523331338447, Perp=28.3671933405139
   ## Epoch [186] Train: NLL=3.30756356274787, Perp=27.31848454823
   ## Epoch [217] Train: NLL=3.25642968403829, Perp=25.9566978956055
   ## Epoch [248] Train: NLL=3.19825967486207, Perp=24.4898727477925
   ## Epoch [279] Train: NLL=3.14013971549828, Perp=23.1070950525017
   ## Epoch [310] Train: NLL=3.08747601837462, Perp=21.9216781782189
   ## Epoch [341] Train: NLL=3.04015595674863, Perp=20.9085038031042
   ## Epoch [372] Train: NLL=2.99839339255659, Perp=20.0532932584534
   ## Epoch [403] Train: NLL=2.95940091012609, Perp=19.2864139984503
   ## Epoch [434] Train: NLL=2.92603311380224, Perp=18.6534872738302
   ## Epoch [465] Train: NLL=2.89482756896395, Perp=18.0803835531869
   ## Epoch [496] Train: NLL=2.86668230478397, Perp=17.5786009078994
   ## Epoch [527] Train: NLL=2.84089368534943, Perp=17.1310684830416
   ## Epoch [558] Train: NLL=2.81725862932279, Perp=16.7309220880514
   ## Epoch [589] Train: NLL=2.79518870141492, Perp=16.3657166956952
   ## Epoch [620] Train: NLL=2.77445683225304, Perp=16.0299176962855
   ## Epoch [651] Train: NLL=2.75490970113174, Perp=15.719621374694
   ## Epoch [682] Train: NLL=2.73697900634351, Perp=15.4402696117257
   ## Epoch [713] Train: NLL=2.72059739336781, Perp=15.1893935780915
   ## Epoch [744] Train: NLL=2.70462837571585, Perp=14.948760335793
   ## Epoch [775] Train: NLL=2.68909904683828, Perp=14.7184093476224
   ## Epoch [806] Train: NLL=2.67460054451836, Perp=14.5065539595711
   ## Epoch [837] Train: NLL=2.66078997776751, Perp=14.3075873113043
   ## Epoch [868] Train: NLL=2.6476781639279, Perp=14.1212134100373
   ## Epoch [899] Train: NLL=2.63529039846876, Perp=13.9473621677371
   ## Epoch [930] Train: NLL=2.62367693518974, Perp=13.7863219168709
   ## Epoch [961] Train: NLL=2.61238282674384, Perp=13.6314936713501
   ## Iter [1] Train: Time: 10301.6818172932 sec, NLL=2.60536539345356, Perp=13.5361704272949
   ## Iter [1] Val: NLL=2.26093848746227, Perp=9.59208699731232

Build Inference from the Model

Use the helper function for random sample:

   cdf <- function(weights) {
       total <- sum(weights)
       result <- c()
       cumsum <- 0
       for (w in weights) {
           cumsum <- cumsum+w
           result <- c(result, cumsum / total)
       }
       return (result)
   }

   search.val <- function(cdf, x) {
       l <- 1
       r <- length(cdf)
       while (l <= r) {
           m <- as.integer((l+r)/2)
           if (cdf[m] < x) {
               l <- m+1
           } else {
               r <- m-1
           }
       }
       return (l)
   }
   choice <- function(weights) {
       cdf.vals <- cdf(as.array(weights))
       x <- runif(1)
       idx <- search.val(cdf.vals, x)
       return (idx)
   }

Use random output or fixed output by choosing the greatest probability:

   make.output <- function(prob, sample=FALSE) {
       if (!sample) {
           idx <- which.max(as.array(prob))
       }
       else {
           idx <- choice(prob)
       }
       return (idx)

   }

In mxnet, we have a function called mx.lstm.inference so that users can build an inference from an LSTM model, and then use the mx.lstm.forward function to get forward output from the inference.

Build an inference from the model:

   infer.model <- mx.lstm.inference(num.lstm.layer=num.lstm.layer,
                                    input.size=vocab,
                                    num.hidden=num.hidden,
                                    num.embed=num.embed,
                                    num.label=vocab,
                                    arg.params=model$arg.params,
                                    ctx=mx.cpu())

Generate a sequence of 75 characters using the mx.lstm.forward function:

   start <- 'a'
   seq.len <- 75
   random.sample <- TRUE

   last.id <- dic[[start]]
   out <- "a"
   for (i in (1:(seq.len-1))) {
       input <- c(last.id-1)
       ret <- mx.lstm.forward(infer.model, input, FALSE)
       infer.model <- ret$model
       prob <- ret$prob
       last.id <- make.output(prob, random.sample)
       out <- paste0(out, lookup.table[[last.id]])
   }
   cat (paste0(out, "\n"))

The result:

   ah not a drobl greens
   Settled asing lately sistering sounted to their hight

Create Other RNN Models

In mxnet, other RNN models, like custom RNN and GRU, are also provided:

  • For a custom RNN model, you can replace mx.lstm with mx.rnn to train an RNN model. You can replace mx.lstm.inference and mx.lstm.forward with mx.rnn.inference and mx.rnn.forward to build inference from an RNN model and get the forward result from the inference model.
  • For a GRU model, you can replace mx.lstm with mx.gru to train a GRU model. You can replace mx.lstm.inference and mx.lstm.forward with mx.gru.inference and mx.gru.forward to build inference from a GRU model and get the forward result from the inference model.