Fast Image Processing with MXNet
Previous tutorials have shown two ways of preprocessing images:
mx.io.ImageRecordIter is fast but inflexible. It is great for simple tasks like image recognition but won’t work for more complex tasks like detection and segmentation.
Python image libraries (skimage, etc.) + numpy are flexible but slow. The root of the problem is Python’s Global Interpreter Lock (GIL). The GIL is a complicated topic, but the gist is that Python doesn’t really support multi-threading for CPU-bound work: even if you spawn multiple threads, execution of Python code will still be serialized. You can work around it with multi-processing and message passing, but that is harder to program and introduces overhead.
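The serialization described above can be observed with a small experiment. The following is a minimal sketch using only the Python standard library; `count_primes` is a hypothetical CPU-bound workload chosen for illustration, not something from MXNet. On a standard CPython build with the GIL, the threaded run typically takes about as long as the serial one, although exact timings depend on the machine and interpreter.

```python
import threading
import time


def count_primes(limit):
    """CPU-bound pure-Python work: count primes below `limit` by trial division."""
    count = 0
    for n in range(2, limit):
        if all(n % d for d in range(2, int(n ** 0.5) + 1)):
            count += 1
    return count


def run_serial(limits):
    """Process every workload one after another in the calling thread."""
    return [count_primes(limit) for limit in limits]


def run_threaded(limits):
    """Process every workload in its own thread and collect the results."""
    results = [None] * len(limits)

    def worker(i, limit):
        results[i] = count_primes(limit)

    threads = [threading.Thread(target=worker, args=(i, limit))
               for i, limit in enumerate(limits)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


limits = [20000] * 4

tic = time.time()
serial = run_serial(limits)
serial_seconds = time.time() - tic

tic = time.time()
threaded = run_threaded(limits)
threaded_seconds = time.time() - tic

# Both strategies compute the same answer; only the wall-clock time is of interest.
print('serial:   %.3f s' % serial_seconds)
print('threaded: %.3f s' % threaded_seconds)
```

Both runs return identical results, so any difference is purely in elapsed time.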
To solve this issue, MXNet provides the mx.image package. It stores images in mx.nd.NDArray format and leverages MXNet’s dependency engine to automatically parallelize image processing and circumvent the GIL.
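To make the idea of asynchronous, parallel execution more concrete, here is a toy analogy built only on the Python standard library. It is not MXNet's engine and does not model dependency tracking. `fake_native_decode` is a hypothetical stand-in for an operation that runs outside the GIL (`time.sleep` releases the GIL, much as native code can). The point it illustrates is that work is handed off and the call returns immediately, so you must wait explicitly before reading a timer or a result, which is the role `mx.nd.waitall()` plays in this tutorial's benchmark.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait


def fake_native_decode(image_id):
    """Hypothetical stand-in for a native operation; time.sleep releases the GIL."""
    time.sleep(0.05)
    return image_id * 2


executor = ThreadPoolExecutor(max_workers=4)

tic = time.time()
# submit() returns a future right away, before the work has finished
futures = [executor.submit(fake_native_decode, i) for i in range(8)]
submit_seconds = time.time() - tic

# Block until every pending task is done, then read the clock
wait(futures)
total_seconds = time.time() - tic
results = [f.result() for f in futures]
executor.shutdown()

print('submitted in %.4f s, finished in %.4f s' % (submit_seconds, total_seconds))
```

Stopping the timer right after the submissions would measure only the cost of queuing the work, not the work itself.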
Please read Intro to NDArray first if you are not familiar with it.
    %matplotlib inline
    from __future__ import print_function
    import os
    import time
    # set the number of threads you want to use before importing mxnet
    os.environ['MXNET_CPU_WORKER_NTHREADS'] = '4'
    import mxnet as mx
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2

    # download example images
    os.system('wget http://data.mxnet.io/data/test_images.tar.gz')
    os.system('tar -xf test_images.tar.gz')
First we load images with mx.image.imdecode. The interface is very similar to opencv, but everything runs in parallel.
We also compare performance against opencv. You can restart the kernel and change MXNET_CPU_WORKER_NTHREADS to see the performance improvement as more threads are used.
    # opencv
    N = 1000
    tic = time.time()
    for i in range(N):
        img = cv2.imread('test_images/ILSVRC2012_val_00000001.JPEG', flags=1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    print(N/(time.time()-tic), 'images decoded per second with opencv')
    plt.imshow(img); plt.show()
263.867350651 images decoded per second with opencv
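    # mx.image
    tic = time.time()
    for i in range(N):
        # open in binary mode so the raw JPEG bytes are passed to imdecode
        img = mx.image.imdecode(open('test_images/ILSVRC2012_val_00000001.JPEG', 'rb').read())
    # decoding is asynchronous; wait for all pending work before stopping the timer
    mx.nd.waitall()
    print(N/(time.time()-tic), 'images decoded per second with mx.image')
    plt.imshow(img.asnumpy()); plt.show()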
951.311962346 images decoded per second with mx.image
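As a quick worked calculation, the two throughput figures printed above can be turned into a speedup ratio. The numbers are copied from the outputs shown here, with MXNET_CPU_WORKER_NTHREADS set to 4. They will differ on other machines, so treat the ratio as illustrative only.

```python
opencv_rate = 263.867350651    # images/s, from the opencv run above
mx_image_rate = 951.311962346  # images/s, from the mx.image run above

# Ratio of the two throughputs: how many times faster mx.image decoded
speedup = mx_image_rate / opencv_rate
print('mx.image is about %.1fx faster than the opencv loop' % speedup)
```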
Once images are loaded as NDArray, you can use mx.nd.* operators to transform them. mx.image provides utility functions for some typical transformations:
    # resize to w x h
    tmp = mx.image.imresize(img, 100, 70)
    plt.imshow(tmp.asnumpy()); plt.show()
    # resize shorter edge to size while preserving aspect ratio
    tmp = mx.image.resize_short(img, 100)
    plt.imshow(tmp.asnumpy()); plt.show()
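The arithmetic behind resizing the shorter edge while preserving the aspect ratio can be sketched in plain Python. `resize_short_dims` is a hypothetical helper written for illustration. It uses integer floor division, which is an assumption, so the library's own rounding may differ by a pixel.

```python
def resize_short_dims(width, height, size):
    """Return (new_width, new_height) with the shorter edge equal to `size`.

    The longer edge is scaled by the same factor, rounded down (an assumption).
    """
    if height > width:
        # portrait: width is the shorter edge
        return size, size * height // width
    # landscape or square: height is the shorter edge
    return size * width // height, size


print(resize_short_dims(500, 375, 100))  # landscape input
print(resize_short_dims(375, 500, 100))  # portrait input
```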
    # crop a random w x h region from image
    tmp, coord = mx.image.random_crop(img, (150, 200))
    print(coord)
    plt.imshow(tmp.asnumpy()); plt.show()
(267, 141, 150, 200)
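The printed tuple reads as (x0, y0, width, height): the top-left corner of the crop followed by its size. The sketch below shows one way such a window could be chosen, using only the standard library and a nested list as a stand-in image. `random_crop_coords` and `crop_rows` are hypothetical names and not part of mx.image, and clamping an oversized crop to the source size is an assumption of this sketch.

```python
import random


def random_crop_coords(src_width, src_height, crop_width, crop_height, rng=random):
    """Pick a crop window (x0, y0, w, h) that lies fully inside the source image."""
    # Clamp the requested size so the window always fits (assumption of this sketch)
    w = min(crop_width, src_width)
    h = min(crop_height, src_height)
    # randint is inclusive on both ends, so the window can touch the far edges
    x0 = rng.randint(0, src_width - w)
    y0 = rng.randint(0, src_height - h)
    return x0, y0, w, h


def crop_rows(image_rows, coord):
    """Crop a row-major image stored as a list of rows."""
    x0, y0, w, h = coord
    return [row[x0:x0 + w] for row in image_rows[y0:y0 + h]]


# Toy 8 x 6 image whose "pixels" record their own (x, y) position
image = [[(x, y) for x in range(8)] for y in range(6)]
coord = random_crop_coords(8, 6, 3, 2)
patch = crop_rows(image, coord)
print(coord, len(patch[0]), 'x', len(patch))
```

Because each toy pixel stores its own position, the top-left pixel of the patch always equals the (x0, y0) part of the returned coordinates.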
Other utility functions include
Given the above functionality, it’s easy to write a custom data iterator. As an example, we provide mx.image.ImageIter, which is similar to Torch’s resnet image loading pipeline. For more details, please see