Optimizing RNN performance

Part I: Investigating performance of GPU BLAS Libraries
Erich Elsen
Baidu Silicon Valley AI Lab



1. Preamble

This is part I of a multi-part series detailing some of the techniques we've used here at Baidu's Silicon Valley AI Lab to accelerate the training of recurrent neural networks. This part focuses on GEMM performance. Later entries might focus on how we parallelize across GPUs, working with half precision, efficient GPU CTC calculation, and GRU and LSTM implementation tricks.

2. Audience

This blog post is targeted towards four main groups:

  1. People that find Deep Learning exciting and want to learn more about it.

  2. Researchers using a Deep Learning framework such as Torch7, Theano or Caffe. This post will help you pick appropriate sizes for your layers and mini-batches to achieve optimal performance. Picking incorrectly can slow down your training by as much as 4x. It will also help you make sure your RNN implementations are as efficient as possible.

  3. The authors of Deep Learning frameworks. There are better options than cuBLAS, as well as some tricks, which, if integrated, could result in a large speed increase for all of your users.

  4. Low-level library implementors. This blog post will detail many of the typical sizes and patterns that are of interest to researchers working with RNNs.

3. Introduction

Most researchers engaged in Neural Network research have been using GPUs for training for some time now due to the speed advantage they have over CPUs. GPUs from NVIDIA are almost universally preferred because they come with high-quality BLAS (cuBLAS) and convolution (cuDNN) libraries. Achieving optimal performance across a wide range of hardware and input sizes is extremely challenging for library writers, and there has been some work outside of NVIDIA on libraries focused on achieving even better performance for the problem sizes relevant to deep learning.

Scott Gray of Nervana Systems has written high-performance GEMM and space-domain convolution libraries for Maxwell-architecture GPUs, which are used in their high-performance deep learning framework Neon. Facebook has focused on frequency-domain techniques for their convolution libraries. We have also written some libraries internally for special cases not covered by any existing libraries.

Here at Baidu's Silicon Valley Artificial Intelligence Lab (SVAIL) we spend a lot of time tackling problems that have sequential data dependencies that are best modeled with Recurrent Neural Networks (RNNs). The cost of evaluating these networks is dominated by matrix-matrix multiplies (the GEMM operation in BLAS terminology).

This is in contrast to Neural Networks that work on images where the data dependencies are hierarchically local and the evaluation cost is almost entirely due to convolutions and related operations.

Most networks consist of a mix of both operations. For example, our speech system has at least one layer of convolution, but the evaluation cost is dominated by the recurrent portions of the network. Image networks, on the other hand, have layers that are calculated using matrix multiplies, but those tend to be an insignificant part of the evaluation cost. This paper has a nice accounting of the flops in the various layers of image-style convnets. For that reason we'll be focusing on GEMM operations in this blog post about RNN performance.

4. Mapping BLAS conventions to NN conventions

Common terminology to describe a matrix problem is the triple (M, N, K), which describes the sizes of the matrices involved, and the “op” which tells us which matrices (if any) are transposed. These two pieces of information combined with the datatype (double, single or half) determine the speed of the operation.

[Figure: example matrix dimensions (M, N, K) for a GEMM]

When we have a fully connected layer in a NN, we generally care about how many units there are at the input and how many are at the output. Additionally, as part of the optimization we have a mini-batch hyper-parameter. The mapping between these numbers is straightforward: M is the number of output units, K is the number of input units, and N is the mini-batch size.

We often ignore the mini-batch when writing our algorithms down, since it is technically a choice of the optimization, which is why the typical equation you see for forward prop through a layer looks like a matrix-vector product $x_{out} = \sigma(Wx_{in} + b_{in})$. The matrix $W$ is M rows by K columns, the vector-like $x_{in}$ is K rows by N columns and the vector-like $x_{out}$ is M rows by N columns. The addition of the bias term, $b_{in}$, and the evaluation of the non-linearity $\sigma$ have a minor effect on performance in most situations, so we will leave them out of discussions of performance.
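
A minimal numpy sketch of this mapping (the layer and mini-batch sizes here are hypothetical, chosen only for illustration):

```python
import numpy as np

M, K, N = 2048, 1024, 64       # output units, input units, mini-batch size

W = np.random.randn(M, K).astype(np.float32)      # weights: M x K
x_in = np.random.randn(K, N).astype(np.float32)   # input activations: K x N
b_in = np.zeros((M, 1), dtype=np.float32)         # bias: M x 1

def sigma(z):
    return np.tanh(z)          # any pointwise non-linearity

# Forward prop is one GEMM with (M, N, K) = (2048, 64, 1024) and op "NN".
x_out = sigma(W @ x_in + b_in)
assert x_out.shape == (M, N)
```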

5. RNN review

There are many introductions and tutorials on the web covering the neat things RNNs can do; one good one is by Andrej Karpathy.

To talk about the performance of RNNs, we just need to look at the equations for going forward and going backward to compute gradients.

The basic equations representing one forward update of an RNN from timestep $t-1$ to $t$ look like:

(1)
\[    z_{t} = Wx_{t} + Uh_{t-1}
\]
(2)
\[   h_{t} = \sigma (z_{t})
\]

where $h$ is the hidden state of the RNN, $x$ is the input from the previous layer, $W$ is the weight matrix for the input and $U$ is the weight matrix for the recurrent connections.

The hidden weight matrix $U$ is necessarily square - the number of hidden units remains the same, so there are as many inputs as outputs and M must always equal K. The input weight matrix $W$ does not have to be square - it can connect an arbitrary number of input units to an arbitrary number of hidden units. Recurrent layers are often stacked, and in that case the only non-square input matrices are the one at the very beginning of the recurrent stack and the one at the very end. More complicated variants of RNNs, such as Gated Recurrent Units (GRUs) and Long Short-Term Memory (LSTM) units, have more recurrent weight matrices, but those are also all square.
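
A minimal numpy sketch of the forward recursion in equations (1) and (2), with hypothetical sizes (bias omitted, $\sigma = \tanh$):

```python
import numpy as np

H, K, N, T = 1152, 1152, 64, 100   # hidden units, input units, mini-batch, timesteps (hypothetical)

W = (0.01 * np.random.randn(H, K)).astype(np.float32)   # input weights
U = (0.01 * np.random.randn(H, H)).astype(np.float32)   # recurrent weights (square)
x = np.random.randn(K, N, T).astype(np.float32)         # inputs for every timestep

h = np.zeros((H, N, T), dtype=np.float32)               # hidden states, saved for backprop
h_prev = np.zeros((H, N), dtype=np.float32)             # h_{-1} = 0

for t in range(T):
    z_t = W @ x[:, :, t] + U @ h_prev                   # equation (1): two "NN" GEMMs
    h[:, :, t] = np.tanh(z_t)                           # equation (2)
    h_prev = h[:, :, t]
```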

The following picture shows how the forward update looks as multiple steps are chained together.

RNN_diagram

To discuss backpropagation, we need to introduce a cost function $J(\mathbf{h}) : \mathbb{R}^N \rightarrow \mathbb{R}$. We want to optimize the parameters in our network using an algorithm based on gradient descent, so we will need to find the derivative of this cost with respect to the parameters $U$ and $W$ of the network, i.e. $\frac{\partial J}{\partial U}$ and $\frac{\partial J}{\partial W}$.

The way that we usually do the computation is not so obvious if you just start taking derivatives, so for curious readers we think that following a derivation of back-propagation for a RNN and seeing how it maps to the computations we perform in practice can be quite useful. If you are interested, check out the derivation at the end of this post. Otherwise, the final result is that we need to perform the four steps below for a sequence with $N+1$ time steps (indexed $0$ to $N$).

  1. Set $\delta_{N-1} = \sigma '(z_N)\frac{\partial J}{\partial h_N}$

  2. Calculate $\delta_{N-2} = \sigma '(z_{N-1})\left( \frac{\partial J}{\partial h_{N-1}} + U^T \delta_{N-1} \right)$

  3. Repeat 2 until reaching $t=0$.

  4. Use equation (3) to get the actual gradients

(3)
\[   \frac{\partial J}{\partial U} = \sum_{t=0}^{t=N-1} \delta_t h_t^T
\]

The derivative with respect to $W$ is exactly the same except for the last step where $x_t$ replaces $h_t$ (you don't need to recalculate $\delta$).

And finally, if there were another layer below the one depicted, then we would need $\frac{\partial J}{\partial x_t}$, which, if you go through a similar derivation, can be shown to be:

(4)
\[    \frac{\partial J}{\partial x_t} = W^T \delta_{t}
\]
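
A minimal numpy sketch of the backward pass, continuing the forward-pass sketch above (same arrays and hypothetical sizes). The dJ_dh array stands in for the $\frac{\partial J}{\partial h_t}$ terms supplied by whatever computes $J$; the loop index $t$ is the timestep, while the $\delta$ index in the text is shifted down by one, as explained in the derivation.

```python
dJ_dh = np.random.randn(H, N, T).astype(np.float32)   # placeholder for dJ/dh_t
sigma_prime = 1.0 - h ** 2                            # tanh'(z_t), expressed via h_t

dU = np.zeros_like(U)
dW = np.zeros_like(W)
dx = np.zeros_like(x)

delta_next = np.zeros((H, N), dtype=np.float32)       # U^T delta from the step after t
for t in range(T - 1, -1, -1):
    # Steps 1 and 2: the delta recursion.
    delta = sigma_prime[:, :, t] * (dJ_dh[:, :, t] + delta_next)
    h_prev = h[:, :, t - 1] if t > 0 else np.zeros((H, N), dtype=np.float32)
    dU += delta @ h_prev.T                            # equation (3): "NT" op
    dW += delta @ x[:, :, t].T                        # same recipe with x_t in place of h
    dx[:, :, t] = W.T @ delta                         # equation (4): "TN" op
    delta_next = U.T @ delta                          # "TN" op feeding the next iteration
```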

In conclusion, the important operations (in terms of performance) will be these three variants of GEMM:

  1. the NN op for the forward-pass multiplications $Wx$ and $Uh$.
  2. the TN op for the backward-pass multiplications $U^T \delta$ and $W^T \delta$.
  3. the NT op for the update calculations $\delta h^T$ and $\delta x^T$.

5.1. Mini-batch

It is widely recognized that increasing the size of a mini-batch is important for decreasing the time to convergence of SGD, because the increase in efficiency due to the larger batch size more than compensates for the increase in iterations required to reach a desired level of accuracy. This is true up to a certain point, which we have found to be at least a mini-batch of 512 in our system. Different networks, solving different problems with different data, may see different results.

Unfortunately, the performance of GEMM libraries is not monotonically increasing with batch size as our mental model might indicate. There are hardware constraints and implementation choices that favor some sizes over others. It is important to examine the actual performance curves and choose batch sizes that yield the best performance.

Mini-batches are usually grouped by sequence length when training on problems with variable length sequences to avoid wasted work computing time steps that are valid for only a small number of elements in a mini-batch.

The upper bound on the size of a mini-batch is determined by the length of sequences in the training data, the available memory, the implementation of the memory allocator, and the data type used to store activations. The activations $x$ and $h$ must be saved for each time step of the RNN, so longer sequences require more memory. If the memory allocator is implemented as raw calls to cudaMalloc, then the mini-batch is constrained by the length of the longest sequence. Schemes with variable-size mini-batches are possible, with smaller mini-batches for long sequences, but this is rarely done in practice. Custom memory allocators that can fall back to host-paged memory (memory that is on the host, but addressable from the GPU) for the longest sequences allow sequences of essentially unlimited length to be trained, at a significant speed penalty for the sequences that are too long for GPU memory and fall back to host memory. With this scheme, the practical limit on the mini-batch is set by the 95th to 99th percentile of sequence lengths.
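
To make the memory constraint concrete, here is a back-of-the-envelope sketch; the layer count, sequence length, and data type below are hypothetical, and weights, gradients, and workspace are ignored.

```python
def activation_bytes(hidden, mini_batch, timesteps, layers, bytes_per_value=4):
    # x and h are saved for every timestep of every recurrent layer
    # (add another factor if the pre-activations z are stored as well).
    per_layer = 2 * hidden * mini_batch * timesteps * bytes_per_value
    return layers * per_layer

# e.g. 5 recurrent layers of 2560 units, mini-batch 64, 1000 timesteps, fp32:
print(activation_bytes(2560, 64, 1000, 5) / 2**30, "GiB")   # ~6.1 GiB
```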

The above considerations mean that, in practice, mini-batches (per GPU) during training tend to be between 32 and 128. Use of half precision can extend the upper bound to 256. Common sizes for the number of hidden units in the recurrent layer range from 512 up to 2560, so the matrix multiplies that are important for RNN evaluation will be $M=K\in[512,2560]$ and $N\in[32, 256]$ for all ops.

6. Performance and Strategies for Improvement

All timings were done on a TitanX with GPUBoost disabled. This means the maximum performance of the card is 6.14 TFLOPS - there are 3072 cores, each capable of one fused multiply-add (FMA, which counts as two flops) per clock, operating at an unboosted frequency of 1GHz. This number (the “speed of light”) serves as a reference for the maximum performance a GEMM operation could achieve. All measurements were made with CUDA 7.0 and driver version 346.59, with the exception of the cuBLAS fp16 measurements, which were made with the CUDA 7.5 beta and driver version 352.07 (the numbers for cuBLAS fp32 did not change significantly). Performance of cuBLAS on Kepler-based cards (like the K40 and K80) will be significantly lower (2x to over 4x) compared to the performance of the Nervana kernels on Maxwell.

The Nervana GEMM library which is benchmarked below is available here. Both Python and C bindings are provided. The kernels themselves expect a row major data layout in contrast to cuBLAS, which expects a column major layout. The C interface has a wrapper that will take care of calling the kernels correctly with column major data.

The fp16 measurements are for a pseudo-fp16 operation, where the inputs and outputs are fp16, but all of the operations are performed in fp32. Using fp16 for storage of weights and activations has many advantages, but it can be non-trivial to get RNNs to converge when training with pseudo-fp16 operations. We will cover some relevant techniques in a later blog post.

Some plots leave out the fp16 performance for the sake of clarity when the pattern of performance remains the same.

6.1. Baseline performance

Assuming we implement our RNN exactly as described in the equations above, we would get performance something like this:

[Figure: NN op performance, M = K = 2560]

[Figure: TN op performance, M = K = 2560]

We can see the clear advantage of getting to a mini-batch of at least 32: performance increases by 14x with cuBLAS over a mini-batch of 1, and by an amazing, nearly linear, almost 30x for the Nervana kernels. cuBLAS has an unexpected performance regression going from a mini-batch of 8 to 9. It is impressive that the performance of the Nervana kernels is almost monotonically increasing with mini-batch size - matching the mental model of most users. Most users would not expect the oscillations present in the cuBLAS performance, and so would not know to avoid the poorly performing sizes.

[Figure: NN op performance, M = K = 2048]

[Figure: NN op performance, M = K = 1024]

[Figure: TN op performance, M = K = 2048]

[Figure: TN op performance, M = K = 1024]
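
The curves above were measured against cuBLAS and the Nervana kernels directly. As a rough way to explore similar curves on your own hardware, here is a timing sketch using CuPy (an assumption on our part - it was not used for the measurements in this post), whose matmul dispatches to cuBLAS:

```python
import time
import cupy as cp

def gemm_tflops(M, N, K, iters=50):
    A = cp.ones((M, K), dtype=cp.float32)   # weight-like matrix
    B = cp.ones((K, N), dtype=cp.float32)   # activation-like matrix
    cp.matmul(A, B)                         # warm up
    cp.cuda.Device().synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        cp.matmul(A, B)
    cp.cuda.Device().synchronize()
    elapsed = (time.perf_counter() - start) / iters
    return 2.0 * M * N * K / elapsed / 1e12  # one FMA = two flops

# Sweep the mini-batch dimension N for a 2560x2560 recurrent weight matrix.
for N in (1, 8, 9, 16, 32, 64, 128, 256):
    print(N, round(gemm_tflops(2560, N, 2560), 2), "TFLOPS")
```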

6.2. Combine across time whenever possible

One important observation is that the multiplications by $W$ in (1) (and by $W^T$ in (4)) have no dependence on each other across timesteps within a layer. This means all $T$ timesteps of those multiplications can be combined into one larger multiplication, which is advantageous because, even though we will still be doing the same amount of work, one large multiplication tends to be more efficient than many small ones (extending the idea of mini-batches increasing performance).

Combining the multiplications can be facilitated by adopting a memory layout like that shown in the following figure:

[Figure: memory layout for combining the timesteps into one larger matrix]

For each timestep the mini-batch is a contiguous chunk of memory, assuming a column-major ordering. Putting the timesteps adjacent in memory then results in a larger matrix where $N$ becomes $N\times T$.

The same technique can be applied to equation (3). Both matrices end up having dimensions $M \times NT$, so the output is $M \times M$ and the sum in (3) is done implicitly. Let's look at what this does to performance.
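
A minimal numpy sketch of the idea, continuing the earlier forward/backward sketches (same arrays and hypothetical sizes). With the layout from the figure the concatenation below is free - it is just a reinterpretation of memory - but numpy's default ordering makes the copy explicit here:

```python
# Group the columns by timestep: x_all is (K, N*T), block t holds x_t.
x_all = np.concatenate([x[:, :, t] for t in range(T)], axis=1)

Wx_all = W @ x_all                   # one (H, N*T) GEMM instead of T small ones
# Wx_all[:, t*N:(t+1)*N] is the W x_t term consumed by the recurrence at step t.

# Equation (3) with the same layout: the sum over t becomes implicit.
delta_all = np.concatenate([dJ_dh[:, :, t] for t in range(T)], axis=1)  # placeholder deltas
dW = delta_all @ x_all.T             # (H, N*T) x (N*T, K) -> (H, K), the "NT" op
```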

[Figure: NN op performance with the timesteps combined, N = 4096]

Performance is now constant at nearly the speed of light for both libraries - there is little dependence on M and K. This is also true for the calculation of the update, the NT op. For most combinations of sequence length and mini-batch (especially once $T\times mb \ge 1024$), there is almost no dependence on the size of the weight matrix. For a weight matrix of 512x512 cuBLAS falls behind Nervana, but otherwise performance is about the same.

[Figure: NT op performance]

Once this optimization is made, the main performance bottleneck will be the recurrent multiplies by $U$.

6.3. Use multiples of 32

Unfortunately, the realities of hardware make it such that the performance curves will not be monotonic. Vector loads, where many values are read at once, cause a preference for sizes that are multiples of 4, 8 or even 16, depending on the datatype, the operation and the dimension (M, N, K). Additionally, on GPUs, threads are grouped into units of 32 called warps, which can cause a strong preference for sizes that are multiples of 32. The easiest way to make sure you're getting the best performance possible is to have every dimension of your problem be a multiple of 32. With cuBLAS the penalty for using a size that is not a multiple of 32 can be over 4x - see the difference between (2560, 128, 2560) and (2560, 125, 2560)! Nervana incurs less than a 25% penalty for most sizes.

Here the problem size is (2560, $\alpha 25$, 2560) vs. (2560, $\alpha 32$, 2560) for $\alpha$ an integer.

[Figure: NN op, M = K = 2560, mini-batch multiples of 25 vs. multiples of 32]

Here the problem size is (1000, $\alpha 25$, 1000):

[Figure: NN op, M = K = 1000]

Here we compare (1000, $\alpha 25$, 1000) to (1024, $\alpha 32$, 1024):

[Figure: NN op, M = K = 1000 vs. M = K = 1024]
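
When the layer size is a free choice, simply pick a multiple of 32. When a dimension is fixed by the data rather than by choice, one way to act on the advice above is to round it up and zero-pad; this is our illustration rather than a trick from any particular library, and the helper below is hypothetical:

```python
import numpy as np

def round_up(n, multiple=32):
    return ((n + multiple - 1) // multiple) * multiple

def padded_gemm(W, x):
    """Compute W @ x after zero-padding the mini-batch dimension of x up to a
    multiple of 32; the extra columns are zero, so slicing them off afterwards
    returns exactly W @ x."""
    K, n = x.shape
    n_pad = round_up(n)
    if n_pad != n:
        x = np.concatenate([x, np.zeros((K, n_pad - n), dtype=x.dtype)], axis=1)
    return (W @ x)[:, :n]

W = np.ones((2560, 2560), dtype=np.float32)
x = np.ones((2560, 125), dtype=np.float32)   # 125: the awkward size from the comparison above
y = padded_gemm(W, x)                        # runs the (2560, 128, 2560) problem instead
```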

6.4. Take advantage of NN vs TN asymmetry

If there is a significant difference in speed between the NN and TN ops, then we can take advantage of it by explicitly transposing the weight matrix so that we can use the faster version. In the case of $M=K=2560$ the NN op is faster than the TN op for many sizes. So during backprop we can explicitly transpose the weight matrix $U$ (we can even do it in-place easily, because $U$ is square) and then call the NN version of the matrix multiply. The transposition cost is negligible compared to a long sequence of multiplies. This trick can even be applied to the large multiplies with $W$ operating over all the forward inputs.

Not all sizes exhibit an asymmetry, but when one does, why not take advantage of it?

[Figure: NN vs. TN op performance, M = K = 2560]

[Figure: NN vs. TN op performance, M = K = 2048]

[Figure: NN vs. TN op performance, M = K = 1024]
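
A minimal sketch of the bookkeeping for this trick, continuing the earlier backward-pass sketch (numpy itself will not show the NN/TN asymmetry; the point is only where the transpose happens):

```python
# Transpose U once up front; then every per-timestep recurrent multiply in the
# backward pass is an "NN" GEMM on U_T instead of a "TN" GEMM on U.
U_T = np.ascontiguousarray(U.T)

delta_next = np.zeros((H, N), dtype=np.float32)
for t in range(T - 1, -1, -1):
    delta = sigma_prime[:, :, t] * (dJ_dh[:, :, t] + delta_next)
    delta_next = U_T @ delta         # previously U.T @ delta, the "TN" op
```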

6.5. Small mini-batches

Performance at small mini-batches [32-64] is especially important when using data parallelism to distribute the computation across GPUs. In synchronous stochastic gradient descent, the updates to the parameters (weights) are synchronized across all the GPUs for every mini-batch, such that there is only one, global, set of network parameters. The optimization is identical to doing a single large mini-batch on one machine (barring numerical differences due to changing the order of floating point operations) - but hopefully faster!

In asynchronous stochastic gradient descent (ASGD), there are multiple, different copies of the network parameters which communicate with each other through a parameter server (exact details vary with the implementation). In this case, the concept of a global mini-batch is less clear. ASGD is used to reduce the cost of synchronizing the updates to the weights between the different copies of the model. In practice, we've found that one of the main downsides of ASGD is that runs are not reproducible. Repeating experiments exactly is impossible, and this can make it very hard to determine the correctness of the code. We have found many subtle bugs by requiring our training code to produce the exact same result every time it is run. Sometimes the bugs will not lead to differences for thousands of iterations. The run-to-run variation of ASGD completely masks such subtle differences. Additionally, it is possible to reduce the cost of the synchronization to a small part of the training time, which reduces the main motivation for using ASGD.

If we have a global mini-batch of 512, then the mini-batch per GPU would only be 64 when using 8 GPUs and 32 when using 16 GPUs. The drop in performance of the matrix multiplies as the mini-batch gets smaller is a bigger hindrance to increasing parallelism than is the increase in communication cost for synchronizing the weight updates.

Performance at very small mini-batches is important in production, where only the NN op is needed (there is no backward pass) and evaluation of a network will often be done with a mini-batch of one - an extremely inefficient use of the hardware. Batching is possible, but only to a limited extent, so achieving the best possible performance for tiny mini-batches is extremely important. In this regime most kernels take the same amount of time regardless of N, which leads to the linear increase in performance as N increases.
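
A toy sketch of what batching at inference time can look like (the buffering policy and shapes here are hypothetical): stack the currently pending requests into one input matrix so the NN op sees N greater than one.

```python
import numpy as np

def forward_batched(W, U, x_seqs):
    """Run several single-stream requests through one recurrent layer at once.
    x_seqs: list of arrays, each (input_units, timesteps), assumed here to have
    the same length for simplicity; stacking them gives a mini-batch of len(x_seqs)."""
    T = x_seqs[0].shape[1]
    H = U.shape[0]
    h = np.zeros((H, len(x_seqs)), dtype=np.float32)
    for t in range(T):
        x_t = np.stack([s[:, t] for s in x_seqs], axis=1)   # (input_units, n_requests)
        h = np.tanh(W @ x_t + U @ h)                         # N = n_requests instead of N = 1
    return h
```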

Performance in this regime is important enough to us that we wrote a custom kernel that doubles the performance of cuBLAS for $N \le 6$. Our kernel actually performs the TN op with an explicitly transposed weight matrix. The TN op kernels from both Nervana and cuBLAS perform worse than the NN op, so they are not included on the plot.

[Figure: NN op, M = K = 2560, small mini-batches]

7. Conclusion

The faster you can perform experiments, the faster you can learn from them and design better experiments - something we call the virtuous cycle of innovation. We hope that this blog post will help RNN researchers and tinkerers run their experiments a little bit faster and reduce their time around the cycle.

For readers in the second group (users of deep learning frameworks) - the main takeaway is to make layer sizes and mini-batch sizes multiples of 32, and if you're using cuBLAS, then make them multiples of 64 for best performance. If you're writing recurrent layers yourself, make sure to write them in such a way that you combine across time when possible.

For readers in the third group (authors of deep learning frameworks) - incorporating the Nervana kernels could provide a large speedup to all of your users with a Maxwell-based NVIDIA GPU. The transpose trick would be beneficial, but not as impactful.

For readers in the fourth group (writers of GEMM libraries) - performance at small batch sizes and near-monotonically increasing performance with layer size and mini-batch size are two very important properties for the performance of training RNNs.

8. Acknowledgments

We would like to thank everyone at Nervana Systems for many fruitful and interesting discussions, and especially Scott Gray for writing the entire high-performance GEMM library (in his spare time)!

9. Derivation

One piece of machinery we'll need is the chain rule for derivatives taken with respect to vectors and matrices rather than scalars. You can see the derivation here as equation F.25. The order of the terms ends up being the opposite of the way you probably learned it in calculus, which is important because matrix multiplication doesn't commute.

(5)
\[  \mathbf{w} = \mathbf{z}(\mathbf{y}(\mathbf{x}))
\]
(6)
\[  \frac{\partial \mathbf{w}}{\partial \mathbf{x}} = \frac{\partial \mathbf{y}}{\partial \mathbf{x}}
                                                    \frac{\partial \mathbf{z}}{\partial \mathbf{y}}
                                                    \frac{\partial \mathbf{w}}{\partial \mathbf{z}}
\]

Now we can apply the multi-variate chain rule to $J$, assuming $N+1$ total time steps:

(7)
\[  \frac{\partial J}{\partial U} = \frac{\partial h_0}{\partial U}\frac{\partial J}{\partial h_0} +
                                  \frac{\partial h_1}{\partial U}\frac{\partial J}{\partial h_1} +
                                  \ldots +
                                  \frac{\partial h_N}{\partial U}\frac{\partial J}{\partial h_N}
\]

Note that whatever calculates $J$ must also calculate $\frac{\partial J}{\partial h_t}$. Let's look at just the last term for now (and slightly expand it):

(8)
\[  \frac{\partial J}{\partial U}=\frac{\partial z_N}{\partial U}
                                \frac{\partial h_N}{\partial z_N}
                                \frac{\partial J}{\partial h_N}
\]

It is instructive to look at the dimensions of each of these terms. The number of elements in the numerator is the number of columns and the number of elements in the denominator is the number of rows. Assuming $U$ is $M\times M$, then the first term which is the derivative of a vector with respect to a matrix is $M^2 \times M$; the second term, the derivative of a vector with respect to a vector (ie a Jacobian) is $M\times M$; and the third term, the derivative of a scalar with respect to a vector is $M \times 1$. This means the dimensions of the product will be $M^2\times 1$, which is as expected since the numerator has 1 element and the denominator has $M^2$ elements, but somewhat counter intuitive since we usually represent this term as a square $M\times M$ matrix in code.

The next step is to expand the first term.

(9)
\[   \frac{\partial z_N}{\partial U}=\frac{\partial (Uh_{N-1} + Wx_N)}{\partial U}=\frac{\partial (Uh_{N-1})}{\partial U}
\]

Unfortunately, the best way to see what happens now is to simply write out this derivative assuming $U$ is $2\times 2$ and $h$ is $2\times 1$.

(10)
\[  U = \begin{bmatrix}
      u_{00} & u_{01} \\
      u_{10} & u_{11}
      \end{bmatrix}
\]
(11)
\[  h = \begin{bmatrix}
      h_{0} \\
      h_{1}
      \end{bmatrix}
\]

If you write out the full product and take the derivatives, it is easy to see that (dropping the time subscript temporarily for clarity):

(12)
\[  \frac{\partial (Uh)}{\partial U}=\renewcommand{\arraystretch}{2.5}\begin{bmatrix}
                                   \dfrac{\partial h_0}{\partial u_{00}} & \dfrac{\partial h_1}{\partial u_{00}} \\
                                   \dfrac{\partial h_0}{\partial u_{01}} & \dfrac{\partial h_1}{\partial u_{01}} \\
                                   \dfrac{\partial h_0}{\partial u_{10}} & \dfrac{\partial h_1}{\partial u_{10}} \\
                                   \dfrac{\partial h_0}{\partial u_{11}} & \dfrac{\partial h_1}{\partial u_{11}}
                                   \end{bmatrix} \begin{bmatrix}
                                   u_{00} & u_{10} \\
                                   u_{01} & u_{11}
                                   \end{bmatrix} + \begin{bmatrix}
                                   h_0 & 0 \\
                                   h_1 & 0 \\
                                   0  & h_0 \\
                                   0  & h_1
                                   \end{bmatrix}
\]

or, more compactly:

(13)
\[  \frac{\partial (Uh)}{\partial U}=\frac{\partial h}{\partial U}U^T + \mathbf{H}
\]

Let's also take a look at the $\frac{\partial h_N}{\partial z_N}$ term in equation (8). It is technically an $M\times M$ matrix, but on closer examination we can see that the only non-zero terms are on the diagonal, and the multiplication by $\frac{\partial J}{\partial h_N}$ can be done as a pointwise multiplication (or Hadamard product) between the diagonal of $\frac{\partial h_N}{\partial z_N}$ and the vector $\frac{\partial J}{\partial h_N}$. We often use the shorthand notation $\sigma '(z_N)$ for the diagonal of $\frac{\partial h_N}{\partial z_N}$.

Now let's combine the last two terms in equation (8) and call the resulting vector $\delta_{N-1}$ (we shift the index by one to make some indices match better in the coming terms). Now we can write:

(14)
\[   \frac{\partial J}{\partial U}=\left (\frac{\partial h_{N-1}}{\partial U}U^T + \mathbf{H_{N-1}} \right )\delta_{N-1}
\]

Looking at the term $\mathbf{H_{N-1}}\delta_{N-1}$, we might wonder if we actually need to perform this full multiplication as $\mathbf{H}$ is mostly zeros. The answer is no, and by writing out the terms of the product, one can show that:

(15)
\[  \mathbf{H_{N-1}}\delta_{N-1} = \delta_{N-1} \mathbf{h_{N-1}}^{T}
\]

with the caveat that we need to reshape one side or the other to make the dimensions match - usually we prefer the $M\times M$ form to the $M^2\times 1$, since the matrix $U$ is $M\times M$. The calculation of this quantity is always done this way in practice. This term represents the contribution to $\frac{\partial J}{\partial U}$ from timestep $N$.

Now we can expand the $\frac{\partial h_{N-1}}{\partial U}$ term in equation (14) to get:

(16)
\[   \frac{\partial J}{\partial U}=\left (\frac{\partial z_{N-1}}{\partial U}
                                        \frac{\partial h_{N-1}}{\partial z_{N-1}} U^T + \mathbf{H_{N-1}} \right )\delta_{N-1}
\]

multiplying the delta through:

(17)
\[   \frac{\partial J}{\partial U}=\frac{\partial z_{N-1}}{\partial U}
                                  \sigma '(z_{N-1}) U^T \delta_{N-1} + \delta_{N-1} h_{N-1}^T
\]

Now is a convenient place to recall the second-to-last term from equation (7), $\frac{\partial z_{N-1}}{\partial U}\sigma '(z_{N-1})\frac{\partial J}{\partial h_{N-1}}$. By combining that term with equation (17), we get:

(18)
\[   \frac{\partial J}{\partial U}=\frac{\partial z_{N-1}}{\partial U}
                                  \sigma '(z_{N-1}) \left(\frac{\partial J}{\partial h_{N-1}} + U^T \delta_{N-1}\right) + \delta_{N-1} h_{N-1}^T
\]

Now we set $\sigma '(z_{N-1}) \left(\frac{\partial J}{\partial h_{N-1}} + U^T \delta_{N-1}\right)=\delta_{N-2}$, expand $\frac{\partial z_{N-1}}{\partial U}$ and incorporate the third-to-last term from (7) to get:

(19)
\[   \frac{\partial J}{\partial U}=\frac{\partial z_{N-2}}{\partial U}\sigma '(z_{N-2})
                              \left (\frac{\partial J}{\partial h_{N-2}} + U^T\delta_{N-2} \right ) +
                              \delta_{N-2}h_{N-2}^T + \delta_{N-1}h_{N-1}^T
\]

If we continue this process, we will eventually arrive at:

(20)
\[   \frac{\partial J}{\partial U} = \sum_{t=0}^{t=N-1} \delta_t h_t^T
\]

The steps for calculating the derivative of $J$ with respect to $U$ are now hopefully clear; they are enumerated back in the RNN review section above.
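
As a sanity check on the derivation, here is a small finite-difference comparison in numpy against the $\sum_t \delta_t h_t^T$ recipe; the cost $J$ used below is a made-up quadratic so that the check is self-contained.

```python
import numpy as np

np.random.seed(0)
H, K, T = 4, 3, 5
U = 0.5 * np.random.randn(H, H)
W = 0.5 * np.random.randn(H, K)
x = np.random.randn(K, T)

def forward(U, W, x):
    h_prev, hs = np.zeros(H), []
    for t in range(T):
        h_prev = np.tanh(W @ x[:, t] + U @ h_prev)
        hs.append(h_prev)
    return np.array(hs)

def cost(U, W, x):                     # J = 0.5 * sum_t ||h_t||^2, chosen only for this check
    return 0.5 * np.sum(forward(U, W, x) ** 2)

# Analytic gradient via the delta recursion enumerated above.
hs = forward(U, W, x)
dU = np.zeros_like(U)
delta_next = np.zeros(H)
for t in range(T - 1, -1, -1):
    dJ_dh = hs[t]                      # dJ/dh_t for the quadratic cost
    delta = (1 - hs[t] ** 2) * (dJ_dh + delta_next)
    h_prev = hs[t - 1] if t > 0 else np.zeros(H)
    dU += np.outer(delta, h_prev)      # the delta * h^T terms of equation (20)
    delta_next = U.T @ delta

# Numerical gradient for one entry of U.
eps, (i, j) = 1e-6, (1, 2)
U_p, U_m = U.copy(), U.copy()
U_p[i, j] += eps
U_m[i, j] -= eps
numeric = (cost(U_p, W, x) - cost(U_m, W, x)) / (2 * eps)
print(dU[i, j], numeric)               # the two numbers should agree closely
```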

