Optimizing RNNs with Differentiable Graphs
Jesse Engel
Baidu Silicon Valley AI Lab



1. Preamble

This is Part II of a multi-part series detailing some of the techniques we've used here at Baidu's Silicon Valley AI Lab (SVAIL) to accelerate the training of recurrent neural networks. While Part I focused on the role that minibatch size and memory layout play in recurrent GEMM performance, we shift our focus here to tricks that we can use to optimize the algorithms themselves.

2. Summary

The main takeaways of this blog post are:

  1. Differentiable graphs are a simple and useful tool for visually calculating complicated derivatives.

  2. These graphs can also inspire algorithmic optimizations. As an example, we show how to accelerate Gated Recurrent Units (GRUs) by up to 40%.

    • Applying the reset gate after the recurrent multiply contributes ~10% of the speedup without loss in accuracy.

2.1. Audience

This post will hopefully be helpful for:

  1. Researchers using frameworks that require explicit gradient calculation, such as Torch or Caffe. We will see how to easily visualize and infer gradients in terms of GPU kernels.

  2. Researchers developing new iterative algorithms. We will develop variations of iterative algorithms such as RNNs that are more efficiently parallelized.

  3. The authors of Deep Learning frameworks that apply auto-differentiation, such as Theano, Tensorflow, Torch Autograd, or Neon. These methods will hopefully provide inspiration for implicit graph optimizations, moving towards systems that better balance the tradeoff between memory usage and computation.

3. Introduction

One of the main reasons for the recent success and popularity of deep learning algorithms is that their performance scales well with dataset size. For example, in our recent work on Deep Speech 2, we found that the word error rate (WER) of our speech system decreases by 40% for every decade increase in data. We see that this power law trend holds over two orders of magnitude of data.

Scaling

Good performance is thus closely tied with being able to train quickly on increasing amounts of data. Deep learning's success is partly due to its core operations, such as matrix multiplication and convolution, having very efficient GPU kernel implementations (e.g. cuBLAS and cuDNN respectively).

Recurrent Neural Networks (RNNs) have iterative dependencies that make them well-suited for sequential tasks, but tricky to efficiently parallelize. Check out Part I of this series for an in-depth review of efficient size/minibatch design principles for RNNs. As with other “neural” operations, the most efficient RNNs require custom GPU implementations such as in cuDNNv5 or in our Persistent RNN kernels.

Imagine, however, that you are developing a new algorithm that doesn't yet have a custom implementation (part of the joy of machine learning research). In this blog post, using RNNs as a case study, we want to explore the question:

What optimizations can we perform when composing a new algorithm out of existing kernels?

4. Differentiable Graph Notation

We will assume the reader has basic familiarity with the backpropagation algorithm. While we avoid derivations in the current discussion for brevity, we highly encourage interested readers to read the derivations in the appendix.

We consider a computational graph composed of nested operations, $h(x)$, that turn input tensors, $x$, into output tensors, $h$. If these tensors are not parameters of the model, we call them “activations”. The derivative of a cost function, $J(x)$, with respect to any tensor in the graph, $\frac{dJ}{dx}$, can be decomposed into an “error” term, $\delta = \frac{dJ}{dh}$, and a derivative of the operation $\frac{dh}{dx}$. On backprop, the error term propagates from operation to operation, accumulating the derivative of each operation along the way. Thus, $\frac{dJ}{dx}$ becomes $\frac{dJ}{dh}$ for the next operation.

(1)
\[\frac{dJ}{dx} = \delta \frac{dh}{dx} = \delta_{next} \]

As an electrical engineer in a previous life, I'm biased to represent the activations as the edges of the graph (like voltages) and the functions as the nodes (like circuit elements). If we also express parameters as nodes, then all of our functions are stateless and the graph is explicit. I find that grounding the visualization in the physics of circuits is helpful, and it is similar in spirit to the approaches of Chris Olah and Andrej Karpathy.1

Let's look at the basic operations we'll use in most differentiable algorithms. Many more complicated operations can be decomposed into these fundamental kernels:

GraphRules1_alt

GraphRules2

where the top diagram is the forward prop rule (up and to the right), and the bottom is the backprop rule (down and to the left). Some things to note:

    • I've marked in red the operations where backprop also requires incoming activations. These are the operations for which we will need to decide whether to save or recompute these activations.

    • All arrays are assumed to be column-major, with activations laid out as [Dim, Minibatch, Time] and weights as [Output, Input].2

4.0.1. Example: Fully Connected Layer

Before we dive into optimizing RNNs, let's take a look at the simple example of a fully connected layer

(2)
\[h = f(Wx + b) \]

where $f$ is an elementwise nonlinearity, and $W$ and $b$ are weight and bias parameters respectively. We can thus draw the forward pass as

FC_forward

and then use our graph propagation rules to read out the derivatives from the backward pass.

FC_backward

(3)
\[\begin{aligned} &\delta_f \!\!\!\!&= \;\; &f'|_{h} \cdot \delta \\ &\delta_b \!\!\!\!&= \;\; &\delta_f \\ &\delta_W \!\!\!\!&= \;\; &\delta_f x^T \\ &\delta_x \!\!\!\!&= \;\; &W^T \delta_f &\end{aligned} \]

The error, $\delta$, gets passed through the nonlinearity (evaluated at $h$) to give $\delta_f$. At the sum, this then gets copied in both directions (to $b$ and the matrix multiply). At the matrix multiply, the incoming signals ($W$ and $x$ in this case) are transposed and sent “around the bend” to the other input.3
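
To make these rules concrete, here is a minimal NumPy sketch of the forward and backward pass, assuming a tanh nonlinearity and the [output, input] / [dims, minibatch] shapes used in this series (the function names are just illustrative):

```python
import numpy as np

def fc_forward(W, b, x, f=np.tanh):
    """h = f(Wx + b); W is [out, in], x is [in, minibatch], b is [out, 1]."""
    return f(W @ x + b)                           # b broadcasts across the minibatch columns

def fc_backward(W, x, h, delta):
    """Read the gradients off the graph (Eq. 3), assuming f = tanh so that
    f'|_h = 1 - h**2 can be evaluated at the stored postactivation."""
    delta_f = (1.0 - h ** 2) * delta              # through the nonlinearity
    delta_b = delta_f.sum(axis=1, keepdims=True)  # broadcasted bias: sum over columns (footnote 3)
    delta_W = delta_f @ x.T                       # transpose sent "around the bend"
    delta_x = W.T @ delta_f
    return delta_W, delta_b, delta_x

W, b, x = np.random.randn(4, 3), np.zeros((4, 1)), np.random.randn(3, 8)
h = fc_forward(W, b, x)
dW, db, dx = fc_backward(W, x, h, delta=np.ones_like(h))
```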

5. Optimizing RNNs

One way to achieve efficient GPU implementations is to try and make the system “compute bound”. If the bottleneck of your parallel algorithm is how fast the processor can run, and you're not doing any unnecessary computation, you're running at “the speed of light” for that processor and doing the best that you can do.

For large matrix multiplications, this is a relatively achievable goal, as the amount of computation required scales with the cube of the matrix dimensions. The parallelization of the problem within a GEMM (General Matrix-Matrix multiply) kernel means that the processors have sufficient work to do, so they don't have to wait on data to arrive.

However, the GEMMs within RNNs are typically smaller because they only act on one timestep at a time. By definition, each timestep depends on the previous timestep and can't be parallelized. These kernels then become “bandwidth bound” and run substantially slower than the processor's potential.4

As these small GEMMs are often the rate-limiting step for RNNs, this leads us to three guiding design principles to increase efficiency. These rules revolve around moving from bandwidth bound layers to compute bound layers.

  1. Concatenate GEMMs

    • Look for independence across dimensions (time, n_neurons) and combine multiple kernels into a single kernel. These larger multiplies relieve some of the bandwidth pressure on the kernels.
  2. Avoid Recomputing GEMMs

    • Adapt models to be functionally similar, but without small GEMMs. Choose stored activations properly to avoid recomputing small GEMMs on backprop.
  3. Use Fused GEMMs

    • When we have a sum after a GEMM, precompute the activations that will be summed and use the version of the kernel that acts in place. Since the Fused Multiply Add (FMA) is a fundamental operation on GPUs, we get the addition for free.

5.1. Example: Vanilla RNN

The “Vanilla” RNN is the classic form of a linear transformation of an input, $x_t$, and a previous hidden state, $h_{t-1}$, with elementwise nonlinearity, $f$, to produce the next hidden state, $h_t$.

(4)
\[h_t = f(Uh_{t-1} + Wx_t) \]

Let's draw out a single timestep, where $U$ is the recurrent weight matrix and $W$ is the input weight matrix

VanillaRNN1

Following our design principles, we can see right away that the input GEMMs are independent for all timesteps. Thus we can concatenate them together by precomputing the inputs for all timesteps ($Wx$). If our array layout is [dims, minibatch, time], no rearrangement is necessary. This is a common optimization employed in many frameworks.

VanillaRNN2

As an added bonus, if we precompute the input GEMMs, we can then use the Fused GEMM in our recurrence to operate in place and get the addition for free.

VanillaRNN3

Now let's take a look at the graph on backprop

VanillaRNN4

Since the $\delta$ values from the layer above are precomputed, we can do an in-place GEMM and get the addition for free again. Remember that a split (copy) of a wire on forward prop is equivalent to a sum on backprop.

VanillaRNN5

In a naive implementation, we would calculate the gradients with respect to the inputs and parameters at each timestep and accumulate them over time. But by now, we know better. Since each parameter leads into a matrix multiplication, we know that we will have to multiply by the transpose of the other input to get the final gradient, so why not do it for all timesteps at once?

Since the sum sends a copy of $\delta$ both ways, we only really need to hold onto the derivatives after the nonlinearity ($\Delta _t$), and then multiply the transposes “around the bend” ($h$, $x$, $W$) to get the gradients ($\delta _U$, $\delta_W$, $\delta_x$).

VanillaRNN6

It's important to note that there is an offset between $\Delta$ and $h$, since $\Delta_t$ is coming from the next timestep and $h_{t-1}$ is coming from the previous timestep. The graph also tells us which activations we need to store on forward prop ($h$ and $x$) as they are required by the matrix multiplies and nonlinearity. Luckily, since we can evaluate the nonlinearity at its postactivations, we can use the same stored $h$ for the recurrent multiply and the nonlinearity on backprop.
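
To make the bookkeeping concrete, here is a rough NumPy sketch of this scheme, assuming a tanh nonlinearity and the [dims, minibatch, time] layout; computing the $\Delta_t$ terms themselves still requires the usual sequential backprop through time, which is not shown:

```python
import numpy as np

def vanilla_rnn_forward(U, W, x, h0, f=np.tanh):
    """x is [in_dim, minibatch, time]. The input GEMM is precomputed for all
    timesteps in one large multiply; the recurrence then adds U h_{t-1} in place."""
    in_dim, n, T = x.shape
    d = U.shape[0]
    pre = (W @ x.reshape(in_dim, n * T)).reshape(d, n, T)   # concatenated input GEMM
    h = np.empty((d, n, T))
    h_prev = h0
    for t in range(T):
        pre[:, :, t] += U @ h_prev       # stands in for the fused, in-place recurrent GEMM
        h[:, :, t] = f(pre[:, :, t])
        h_prev = h[:, :, t]
    return h

def vanilla_rnn_weight_grads(U, W, x, h, h0, Delta):
    """Delta[:, :, t] is the error after the nonlinearity at timestep t. The
    parameter and input gradients then reduce to single large GEMMs."""
    d, n, T = h.shape
    h_shift = np.concatenate([h0[:, :, None], h[:, :, :-1]], axis=2)  # offset by one timestep
    D = Delta.reshape(d, n * T)
    dU = D @ h_shift.reshape(d, n * T).T
    dW = D @ x.reshape(x.shape[0], n * T).T
    dx = (W.T @ D).reshape(x.shape)
    return dU, dW, dx
```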

5.2. Example: Gated Recurrent Units (GRU)

Now let's look at the more complicated Gated Recurrent Unit. We'll focus on the GRU rather than the more commonly used LSTM because it is a bit less straightforward and more interesting to optimize. Many of the tricks we use can be applied to LSTMs as well.

Along the way, we'll develop four different variants of the GRU that we will then compare in terms of speed and performance in the next section.

5.2.1. Variant 1: Standard GRU

GRUs differ from Vanilla RNNs in that they have two “gates” (reset gate $r_t$ and update gate $z_t$). Each gate has its own input weights ($W_r$, $W_z$) and recurrent weights ($U_r$, $U_z$), and a sigmoid nonlinearity to constrain the output between 0 and 1 (no gain applied).

(5)
\[\begin{aligned} &z_t \!\!\!\! &= \;\; &\sigma(U_z h_{t-1} + W_z x_t) \\ &r_t \!\!\!\! &= \;\; &\sigma(U_r h_{t-1} + W_r x_t) \\ &\widetilde{h_t} \!\!\!\! &= \;\; &f(U(r_t \cdot h_{t-1}) + W x_t) \\ &h_t \!\!\!\! &= \;\; &z_t \cdot \widetilde{h_t} + (1-z_t) \cdot h_{t-1} &\end{aligned} \]

The reset gate modulates the recurrence, allowing for a ‘reset’ of the hidden state, and the update gate controls whether or not to update the hidden state of the RNN. If $z_t$ is 0, then the hidden state remains constant, allowing long-term dependencies. The forward pass appears a bit more complicated than a Vanilla RNN

GRU1

For simplicity, I've combined an elementwise multiplication of -1 and an elementwise addition of 1 into $1-\{\}$, the derivative ($\frac{dh}{dx}$) of which is -1.
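
For reference in the variants that follow, a single timestep of this standard GRU can be sketched in NumPy as below, assuming the input products $W_z x_t$, $W_r x_t$, and $W x_t$ have already been precomputed and that $f$ is tanh:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(h_prev, Wz_x, Wr_x, W_x, Uz, Ur, U, f=np.tanh):
    """One timestep of the standard GRU (Eq. 5)."""
    z = sigmoid(Uz @ h_prev + Wz_x)         # update gate
    r = sigmoid(Ur @ h_prev + Wr_x)         # reset gate
    h_cand = f(U @ (r * h_prev) + W_x)      # reset applied *before* the recurrent GEMM
    return z * h_cand + (1.0 - z) * h_prev
```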

5.2.2. Variant 2: Concatenate across timesteps and gates

To start off, we can concatenate the input matrix multiplications in time like we did with the Vanilla RNN. Further, the input GEMMs for the gates are independent of one another, so we can concatenate all the $W$ matrices together into a single rectangular matrix. We can apply a similar trick to $U_z$ and $U_r$ (not shown), but not to $U$, since its input must first be multiplied by the output of the reset gate.

GRU2

Similar to the Vanilla RNN, we see that precomputing the inputs allows us to apply the recurrent GEMMs inplace, getting the elementwise addition for free.

GRU3

5.2.3. Variant 3: Move the Reset Gate

Not being able to concatenate $U$ with $U_r$ and $U_z$ is doubly painful because it requires us to perform a small GEMM on both forward prop and backprop (or save an additional activation). Notice that there are two multiplies in a row (elementwise then matrix), both of which require activations on backprop. The elementwise multiply can share activations with the nonlinearity of the reset gate, but the $U$ multiply cannot share the stored $h_{t-1}$ like the other recurrent multiplies, since its input is $r_t \cdot h_{t-1}$.

For these reasons, it seems like it would be beneficial to swap the order of the matrix multiply, $U$, and elementwise multiplication by the reset gate, $r_t$.

(6)
\[\begin{aligned} &z_t \!\!\!\! &= \;\; &\sigma(U_z h_{t-1} + W_z x_t) \\ &r_t \!\!\!\! &= \;\; &\sigma(U_r h_{t-1} + W_r x_t) \\ &\widetilde{h_t} \!\!\!\! &= \;\; &f(r_t \cdot U h_{t-1} + W x_t) \\ &h_t \!\!\!\! &= \;\; &z_t \cdot \widetilde{h_t} + (1-z_t) \cdot h_{t-1} &\end{aligned} \]

While this changes the function being computed, it has similar functionality, and produces extremely similar training / accuracy performance in many of our experiments. It also makes all the recurrent multiplies independent of each other

GRU4

so that they can then all be concatenated. Notice, however, that we can no longer apply a Fused GEMM for $U$. It turns out to be a very good tradeoff to make a larger GEMM at the cost of an additional sum.

GRU5
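
A rough NumPy sketch of this concatenated forward step is below; `U_all` is a hypothetical name for the stacked $[U_z; U_r; U]$ matrix, and the input products are again assumed precomputed:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step_moved_reset(h_prev, Wz_x, Wr_x, W_x, U_all, f=np.tanh):
    """Variant 3: the reset gate multiplies *after* the recurrent GEMM, so
    U_z, U_r, and U stack into a single [3d, d] matrix and run as one GEMM."""
    d = h_prev.shape[0]
    rec = U_all @ h_prev                    # one large, concatenated recurrent GEMM
    z = sigmoid(rec[:d] + Wz_x)
    r = sigmoid(rec[d:2 * d] + Wr_x)
    h_fwd = rec[2 * d:]                     # = U h_{t-1}
    h_cand = f(r * h_fwd + W_x)             # W x_t added after the gate: the extra sum
    return z * h_cand + (1.0 - z) * h_prev
```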

Now let's examine the GRU on backprop

GRU6

Like the Vanilla RNN, Variant 3 of the GRU allows us to use an in-place GEMM to sum with the $\delta_t$ values from the layer above. Remember that a split (copy) of a wire on forward prop is equivalent to a sum on backprop.

GRU7

Remember that we need an activation (either through storage or recomputation) at each multiplication and nonlinearity. Storing activations at each nonlinearity seems like a good starting place, as they all have a multiplication afterwards that can share the activation. In the case of $z_t$, we need to do a very small recomputation ($1-\{\}$) to get to the elementwise multiply, which is not very limiting. However, we are still left with two multiplications in a row in the case of $U$, so we need to either store an additional activation or recompute a small GEMM on backprop.

5.2.4. Variant 4: Save different activations

We can avoid this GEMM on backprop without additional memory consumption if we store activations at both multiplications, but not at the nonlinearity $\widetilde{h_t}$. If we store the activations at $h_{fwd}$ instead, we replace the recomputation of a small GEMM ($Uh_{t-1}$) with an elementwise multiply ($r_t$), sum ($Wx_t$), and nonlinearity ($f$).

GRU8

Intuitively, this will actually require less computation on backprop, but have additional fixed latency costs of launching more kernels.
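
A minimal sketch of that tradeoff, assuming tanh for $f$: if $r_t$ and $h_{fwd} = U h_{t-1}$ are stored on forward prop (along with the precomputed $W x_t$ we already have), then backprop can recover the candidate state and the nonlinearity's derivative with purely elementwise work:

```python
import numpy as np

def recompute_candidate(r_t, h_fwd, Wx_t, f=np.tanh):
    """Variant 4 backprop helper: rebuild h_tilde from stored activations
    instead of re-running the small GEMM U h_{t-1}."""
    h_cand = f(r_t * h_fwd + Wx_t)    # elementwise multiply + sum + nonlinearity
    df = 1.0 - h_cand ** 2            # f'|_h for tanh, evaluated at the postactivation
    return h_cand, df
```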

6. GRU Performance Comparison

Of course, the only real way to know if an optimization is helpful is to measure it. We run a baseline comparison for a series of different GRU sizes and minibatch sizes. Each configuration is run for 10,000 iterations to get consistent statistics. We compare the original GRU (Variant 1, solid line) to the final variant (Variant 4, top of the shaded region). The width of the shaded region represents the improvement.

BaselineGRU_Fill

Unsurprisingly, the efficiency of the layer, as measured in TFlops/second, increases dramatically as the sizes grow and the small GEMM kernels become less bandwidth bound. The GRU modifications help to offset the effect of the small matrix sizes, with the optimized GRU at minibatch 64 performing about as well as the original GRU with a minibatch of 128.5

Now, let's pick a single minibatch (128), and compare the four variants of GRU detailed above (Original, Concatenated, Moved Reset Gate, Different Saved Activations). We can find the relative speed of computation by comparing the total time spent inside the GPU kernels.

TimeInKernels

In general, each optimization improves the performance, with the total improvement ranging from 20% to 40%. It's interesting to note that for very large sizes (2048), the optimizations help less in a relative sense because the larger GEMMs are less of a bottleneck.

Conversely, the final optimization, which saves recomputing a GEMM at the cost of a couple of elementwise recomputations, gives more relative improvement at larger sizes. While this might seem counterintuitive, since small GEMMs are more of a bottleneck at small sizes, that effect is outweighed by the reduction in total computation required, which is greater for larger sizes.

7. Conclusion

To recap our takeaways:

  1. Differentiable graphs are a simple and useful tool for visually calculating complicated derivatives, and for deciding which activations to save or recompute.

  2. These graphs can also inspire algorithmic optimizations. By concatenating GEMMs, avoiding recomputed GEMMs, and using fused GEMMs, the GRU variants developed here run up to 40% faster, with the moved reset gate contributing roughly 10% of the speedup without a loss in accuracy.

8. Acknowledgments

I would like to thank Shubho Sengupta for his help with profiling and implementing GRU variant #4.

9. Appendix: Differentiable Graph Derivations


Let's start with an example. We'll look at the case of two nested functions (“layers”, F and G). This can easily be extended to any number of layers. We seek to use gradient descent to find the parameters, $\theta$, that minimize a composite “cost” function, $J$, given data, $x$.

(7)
\[J = G_{\theta_G} ( F_{\theta_F} (x) ) \]

Our goal is to find the derivatives of our cost with respect to our parameters ($\frac{dJ}{d\theta_G}, \frac{dJ}{d\theta_F}$). Using the chain rule:

(8)
\[\begin{aligned} \frac{dJ}{d\theta_G} &= \frac{dJ}{dG}\frac{dG}{d\theta_G} \\ \frac{dJ}{d\theta_F} &= \frac{dJ}{dG}\frac{dG}{dF}\frac{dF}{d\theta_F} \end{aligned} \]

We can see that there is no need to calculate $\frac{dJ}{dG}$ twice, and we can propagate this “error signal”, $\delta = \frac{dJ}{dG}$, backwards from layer G to layer F.

Let's extend this to an arbitrary layer, $h_\theta(x)$, that turns inputs, $x$, into outputs (activations), $h$, using parameters, $\theta$. On backprop, the propagating error, $\delta$, gives the gradient of the cost with respect to the parameters, and is itself propagated along by multiplying by the layer's derivative of its outputs with respect to its inputs.

(9)
\[\frac{dJ}{d\theta} = \delta \frac{dh}{d\theta} \]
(10)
\[\delta_{next} = \delta \frac{dh}{dx} \]

We want to develop a graph notation that encapsulates both aspects of this propagation. As we shall see, we will only need equation (10) if we require that all operations be stateless and not contain parameters.

As an aside, a common misconception is that the backpropagation algorithm is “just the chain rule”. Chris Olah's blog does a great job of highlighting how backpropagation is actually a dynamic programming algorithm that allows us to avoid calculating derivatives over every possible path from inputs to outputs (forward-mode differentiation, $\frac{d}{dx}$) in exchange for caching intermediate activations (reverse-mode differentiation, $\frac{dJ}{d}$). Also see the appendix of MacLaurin et al.

There is implicitly a tradeoff between memory and compute in backprop. If we didn't save any activations we would have to recompute all the activations up to each layer repeatedly, scaling polynomially with the number of layers. Saving all activations requires no recomputation and scales linearly with the number of layers. However, for memory constrained systems such as large RNNs, this is wasteful. Making clever choices of which activations to save and which to recompute can jointly optimize for memory and computation.

Let's look at the basic operations we'll use in most differentiable algorithms. Many more complicated operations can be decomposed into these fundamental kernels:

9.1. Sum

Adding two things together is a simple and effective way of combining information from different sources. We use it for adding a bias to fully connected layers and combining input and recurrent activations in an RNN. It's a fusing operation as we have two inputs and a single output.

(11)
\[h = x_1 + x_2 \]

As there are no parameters, following equation (10), we only need to find the derivatives with respect to both inputs ($\frac{dh}{dx_1}, \frac{dh}{dx_2}$).

(12)
\[\begin{aligned} \frac{dh}{dx_1} &= 1 &\frac{dh}{dx_2} &= 1 \\ \delta_1 &= \delta & \delta_2 &= \delta \end{aligned} \]

So we can see that on backprop, we actually just copy the error term and send it back through both inputs.

9.2. Copy

As a sum is to combining information, a copy is to sharing it. Copies are great for branching algorithms that send the data through multiple paths to be later combined. A copy is a splitting operation as it has one input and two outputs.

(13)
\[\begin{aligned} h_1 &= x & h_2 &= x \end{aligned} \]

For backprop, we need $\frac{dh}{dx}$, but we have two outputs. If we assume the final cost function depends on both outputs, $J(h_1(x), h_2(x))$, then by the chain rule:

(14)
\[\begin{aligned} \frac{dJ}{dx} &= \frac{dJ}{dh_1}\frac{dh_1}{dx} + \frac{dJ}{dh_2}\frac{dh_2}{dx} \\ \frac{dJ}{dx} &= \frac{dJ}{dh_1} \cdot 1 + \frac{dJ}{dh_2} \cdot 1 \\ \delta &= \delta_1 + \delta_2 \end{aligned} \]

We end up propagating the sum of the derivatives. It's interesting to note that the Sum and Copy operations are thus complementary to each other: if you do one on forward prop, you do the other on backprop.

A nice way of thinking of this is that forward prop corresponds to creating an explanation of our data, and backprop corresponds to assigning responsibility within the model for that explanation. Whether our explanation is right or wrong, we know which parameters are responsible. So a Sum says that a combination of two factors explains our data and both are equally responsible. Conversely, a Copy says that a single factor explains our data in two different ways and bears responsibility for both explanatory paths.

9.3. Elementwise Multiply

An elementwise multiplication is an independent gate or gain on each dimension in the input signal. These multiplicative interactions are essential to more advanced forms of RNNs such as Gated Recurrent Units (GRU) and Long Short Term Memory Units (LSTM).

(15)
\[h = x_1 \cdot x_2 \]

As with addition, we need to pass the derivatives to each input.

(16)
\[\begin{aligned} \frac{dh}{dx_1} &= x_2 &\frac{dh}{dx_2} &= x_1 \\ \delta_1 &= \delta \cdot x_2 & \delta_2 &= x_1 \cdot \delta \end{aligned} \]

Note that activations at each input are passed to the opposite input on backprop.
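
A quick NumPy check of this rule, using a scalar cost $J = \sum h$ so that the incoming error is a vector of ones (purely illustrative):

```python
import numpy as np

x1, x2 = np.random.randn(5), np.random.randn(5)
delta = np.ones(5)                         # dJ/dh for J = sum(x1 * x2)
d1, d2 = delta * x2, x1 * delta            # Eq. (16): each input sees the other's activation

eps = 1e-6
x1p = x1.copy(); x1p[0] += eps
numeric = ((x1p * x2).sum() - (x1 * x2).sum()) / eps
assert abs(numeric - d1[0]) < 1e-4         # finite difference on x1[0] agrees
```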

9.4. Matrix Multiply

Matrix multiplication is the core of deep learning. It is the primary operation of fully connected and recurrent layers, representing an all-to-all connectivity of a collection of input and output dimensions.

(17)
\[h = x_1 x_2 \]

Derivatives are basically the same as with elementwise multiplication, except over many dimensions. There are many rules for handling matrix derivatives, but a nice hack is to take derivatives as if the matrices were 1-D variables, and then take the transpose to make sure the dimensions are correct.

For example, if $x_1, \delta_1 \in \mathbb{R} ^{N \times K}$ and $x_2, \delta_2 \in \mathbb{R} ^{K \times M}$, then $h, \delta \in \mathbb{R} ^{N \times M}$. Since $\delta_1 = \delta \frac{dh}{dx_1}$, it implies $\frac{dh}{dx_1} \in \mathbb{R} ^{M \times K}$ which is $x_2^T$.

(18)
\[\begin{aligned} \frac{dh}{dx_1} &= x_2^T &\frac{dh}{dx_2} &= x_1^T \\ \delta_1 &= \delta x_2^T & \delta_2 &= x_1^T \delta \end{aligned} \]

As a rule of thumb, this amounts to preserving the order of the multiplication and taking the transpose of the activations in place.

Consider the special case where $x_1 = W$, a parameter weight matrix:

(19)
\[h = Wx \]
(20)
\[\begin{aligned} \frac{dh}{dW} &= x^T &\frac{dh}{dx} &= W^T \\ \frac{dJ}{dW} &= \delta x^T & \frac{dJ}{dx} &= W^T \delta \end{aligned} \]

Notice that by specifying the parameter $W$ as one of the inputs, we solved for the gradients of both the parameters and the inputs using only equation (10), the derivative with respect to the inputs. This trick will help make our graph notation much cleaner.
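
The same kind of finite-difference check works for the matrix case (Eqs. 18 and 20); with a scalar cost $J = \sum h$, the incoming error is a matrix of ones:

```python
import numpy as np

x1, x2 = np.random.randn(4, 3), np.random.randn(3, 5)
delta = np.ones((4, 5))                    # dJ/dh for J = sum(x1 @ x2)

d1, d2 = delta @ x2.T, x1.T @ delta        # order preserved, transposes taken in place

eps, i, j = 1e-6, 1, 2
x1p = x1.copy(); x1p[i, j] += eps
numeric = ((x1p @ x2).sum() - (x1 @ x2).sum()) / eps
assert abs(numeric - d1[i, j]) < 1e-4      # finite difference agrees with delta @ x2.T
```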

9.5. Elementwise Nonlinearity

The elementwise nonlinearity is where the magic happens. Without it, we would only have a linear network, regardless of how many layers we added. Common choices include Sigmoid ($\sigma$), Tanh, and the many variants of rectifier (ReLU, pReLU, etc.), which we will simply represent by the function $f$. By the chain rule, the derivative of a nonlinearity is evaluated at the value prior to the nonlinearity

(21)
\[h = f(x) \]
(22)
\[\begin{aligned} \frac{dh}{dx} &= f'|_x \\ \delta &= f'|_x \cdot \delta \end{aligned} \]

As it turns out, the derivatives of many nonlinearities are actually easier to evaluate at the postactivation ($f'|_h$) rather than the preactivation ($f'|_x$). For example, for a Sigmoid ($h = \sigma(x)$)

(23)
\[\begin{aligned} \frac{dh}{dx} &= h \cdot (1-h) \end{aligned} \]

And for a ReLU ($h = ReLU(x)$), the derivative is simply an indicator on the postactivation

(24)
\[\begin{aligned} \frac{dh}{dx} &= \mathbb{1}_{h > 0} \end{aligned} \]

Note that like Elementwise Multiply and Matrix Multiply, Nonlinearities require saving or recomputing activations for backprop.
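
In NumPy terms, only the postactivation $h$ needs to be kept around for either case (sigmoid as in Eq. 23, ReLU as in Eq. 24):

```python
import numpy as np

x = np.random.randn(6)

h_sig = 1.0 / (1.0 + np.exp(-x))
dsig = h_sig * (1.0 - h_sig)              # sigma'(x) written entirely in terms of h

h_relu = np.maximum(0.0, x)
drelu = (h_relu > 0).astype(x.dtype)      # ReLU'(x): 1 where the output is positive, else 0
```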



1.Note that this is in many ways the opposite of the notation for Directed Probabilistic Graphical Models, where nodes represent both parameters and intermediate variables, and edges correspond to functional relationships. It is more similar in spirit to Factor Graphs, where the factors correspond to the functional “circuit elements”.

2.In this blog post, I will assume column-major or Fortran ordering of arrays. This is different than the default for Python or Torch, but a better match to BLAS libraries and consistent with prior blog posts. This means that the dimensions of activations will be [Dim, Minibatch, Time], and weights will be [Output, Input]. See part I for details.

3.A small side note about the bias. The bias is a vector (ex. $\mathbb{R} ^{M \times 1}$) that is added to a matrix (ex. $\mathbb{R} ^{M \times N}$), so it needs to be broadcast across the columns. This is mathematically equivalent to adding $b \vec{1}$ instead of $b$, where $\vec{1} \in \mathbb{R} ^{1 \times N}$ is a vector of ones. For simplicity, I have not included this implicit broadcasting in the computational graph, but following the graph propagation rules this would leave the derivative of the bias as $\delta_b = \delta_f \vec{1}^T$, which is just the same as summing over the column dimension. It's also easy to see that this is required to make the sizes of the dimensions match, where $\delta_f \in \mathbb{R} ^{M \times N}$ and $b, \delta_b \in \mathbb{R} ^{M \times 1}$, which requires that $\vec{1}^T \in \mathbb{R} ^{N \times 1}$. Basically, as long as the dimensions of your parameters and gradients match, you should be doing okay.

4.One approach to create more work per kernel call is to increase the minibatch size, as each example can be computed independently. Alternatively, you can use GEMM kernels tuned for smaller matrix sizes (see Figure 7 of our Deep Speech 2 paper). Persistent RNN kernels are also effective at decreasing the bandwidth demands by storing the weights in the thread memory so that they don't need to be reloaded each timestep.

5.All timings are done on a TitanX with GPUBoost disabled. This means the maximum performance of the card is 6.14TFlops/second - there are 3072 cores, each capable of 1 fused multiply-add (FMA) per clock, operating at an unboosted frequency of 1GHz.