.. SPDX-FileCopyrightText: 2019-2020 Intel Corporation
..
.. SPDX-License-Identifier: CC-BY-4.0
.. default-domain:: cpp
.. include:: ../replacements.inc.rst
.. _rnn-label:
###
RNN
###
The RNN primitive computes a stack of unrolled recurrent cells, as depicted in
Figure 1. :math:`\bias`, :math:`\srciter` and :math:`\dstiter` are optional
parameters. If not provided, :math:`\bias` and :math:`\srciter` default to 0.
Variable names follow the standard :ref:`conventions-label`.
.. image:: ../_static/unrolled_stack_rnn.jpg
The RNN primitive supports four modes of evaluation direction, mapped to
``dnnl::rnn_direction`` enumerators in the sketch after this list:

- ``left2right`` processes the input data timestamps in increasing order,
- ``right2left`` processes the input data timestamps in decreasing order,
- ``bidirectional_concat`` processes all the stacked layers from
  ``left2right`` and from ``right2left`` independently, and concatenates the
  outputs in :math:`\dstlayer` over the channel dimension,
- ``bidirectional_sum`` processes all the stacked layers from ``left2right``
  and from ``right2left`` independently, and sums the two outputs into
  :math:`\dstlayer`.
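For reference, a minimal sketch of how these modes appear in the C++ API; the
enumerator names are those of ``dnnl::rnn_direction`` documented in the API
section below.

.. code:: cpp

   // The four evaluation directions exposed by the C++ API.
   auto l2r = dnnl::rnn_direction::unidirectional_left2right;
   auto r2l = dnnl::rnn_direction::unidirectional_right2left;
   auto bi_concat = dnnl::rnn_direction::bidirectional_concat;
   auto bi_sum = dnnl::rnn_direction::bidirectional_sum;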
Even though the RNN primitive supports passing a different number of channels
for :math:`\srclayer`, :math:`\srciter`, :math:`\dstlayer`, and
:math:`\dstiter`, we always require the following conditions in order for the
dimensions to be consistent:
- :math:`channels(\dstlayer) = channels(\dstiter)`,
- when :math:`T > 1`, :math:`channels(\srciter) = channels(\dstiter)`,
- when :math:`L > 1`, :math:`channels(\srclayer) = channels(\dstlayer)`,
- when using the ``bidirectional_concat`` direction,
:math:`channels(\dstlayer) = 2 * channels(\dstiter)`.
The general formula for the execution of a stack of unrolled recurrent cells
depends on the current iteration of the previous layer (:math:`h_{t,l-1}`)
and the previous iteration of the current layer (:math:`h_{t-1,l}`, plus
:math:`c_{t-1,l}` for LSTM cells). Here is the exact equation for non-LSTM
cells:
.. math::
h_{t, l} = Cell(h_{t, l-1}, h_{t-1, l})
where :math:`t` and :math:`l` are the timestamp and layer indices of the cell
being executed.
And here is the equation for LSTM cells:
.. math::
(h_{t, l},c_{t,l}) = Cell(h_{t, l-1}, h_{t-1, l}, c_{t-1,l})
where :math:`t` and :math:`l` follow the same index conventions as above.
**************
Cell Functions
**************
The RNN API provides six cell functions:
- :ref:`Vanilla RNN <vanilla_rnn-label>`, a single-gate recurrent cell,
- :ref:`LSTM <lstm-label>`, a four-gate long short-term memory cell,
- :ref:`GRU <gru-label>`, a three-gate gated recurrent unit cell,
- :ref:`Linear-before-reset GRU <lbr_gru-label>`, a three-gate recurrent unit
  cell with the linear layer before the reset gate,
- :ref:`AUGRU <augru-label>`, a three-gate gated recurrent unit cell
  with the attention update gate,
- :ref:`Linear-before-reset AUGRU <lbr_augru-label>`, a three-gate
  recurrent unit cell with the linear layer before the reset gate and
  the attention update gate.
.. _vanilla_rnn-label:
Vanilla RNN
===========
A single-gate recurrent cell initialized with
|vanilla_rnn_forward::primitive_desc| or
|vanilla_rnn_backward::primitive_desc| as in the following example.
.. code:: cpp
auto vanilla_rnn_pd =
dnnl::vanilla_rnn_forward::primitive_desc(engine, aprop,
activation, direction, src_layer_desc, src_iter_desc,
weights_layer_desc, weights_iter_desc, bias_desc,
dst_layer_desc, dst_iter_desc, attr);
The Vanilla RNN cell should support the ReLU, Tanh, and Sigmoid activation
functions. The following equations define the mathematical operation
performed by the Vanilla RNN cell for the forward pass:
.. math::
a_t &= W \cdot h_{t,l-1} + U \cdot h_{t-1, l} + B \\
h_t &= activation(a_t)
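The ``activation`` parameter above selects among these functions; a minimal
sketch of the corresponding ``dnnl::algorithm`` values:

.. code:: cpp

   // Activation kinds accepted by the Vanilla RNN cell.
   auto relu_act = dnnl::algorithm::eltwise_relu;        // max(a, 0)
   auto tanh_act = dnnl::algorithm::eltwise_tanh;        // tanh(a)
   auto sigmoid_act = dnnl::algorithm::eltwise_logistic; // 1 / (1 + exp(-a))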
.. _lstm-label:
LSTM
====
LSTM (or Vanilla LSTM)
----------------------
A four-gate long short-term memory recurrent cell initialized with
|lstm_forward::primitive_desc| or |lstm_backward::primitive_desc| as
in the following example.
.. code:: cpp
auto lstm_pd = dnnl::lstm_forward::primitive_desc(engine, aprop,
direction, src_layer_desc, src_iter_h_desc, src_iter_c_desc,
weights_layer_desc, weights_iter_desc, bias_desc,
dst_layer_desc, dst_iter_h_desc, dst_iter_c_desc, attr);
Note that for all tensors with a dimension depending on the number of gates,
we implicitly require the order of these gates to be :math:`i`, :math:`f`,
:math:`\tilde c`, and :math:`o`. The following equations give the
mathematical description of these gates and output for the forward pass:
.. math::
i_t &= \sigma(W_i \cdot h_{t,l-1} + U_i \cdot h_{t-1, l} + B_i) \\
f_t &= \sigma(W_f \cdot h_{t,l-1} + U_f \cdot h_{t-1, l} + B_f) \\
\\
   \tilde c_t &= \tanh(W_{\tilde c} \cdot h_{t,l-1} + U_{\tilde c} \cdot h_{t-1, l} + B_{\tilde c}) \\
   c_t &= f_t * c_{t-1} + i_t * \tilde c_t \\
   \\
   o_t &= \sigma(W_o \cdot h_{t,l-1} + U_o \cdot h_{t-1, l} + B_o) \\
   h_t &= \tanh(c_t) * o_t
where :math:`W_*` are stored in :math:`\weightslayer`, :math:`U_*` are stored
in :math:`\weightsiter` and :math:`B_*` are stored in :math:`\bias`.
.. note::
In order for the dimensions to be consistent, we require
:math:`channels(\srciterc) = channels(\dstiterc) = channels(\dstiter)`.
LSTM with Peephole
------------------
A four-gate long short-term memory recurrent cell with peephole
initialized with |lstm_forward::primitive_desc| or
|lstm_backward::primitive_desc| as in the following example.
.. code:: cpp
auto lstm_pd = dnnl::lstm_forward::primitive_desc(engine, aprop,
direction, src_layer_desc, src_iter_h_desc, src_iter_c_desc,
weights_layer_desc, weights_iter_desc, weights_peephole_desc,
bias_desc, dst_layer_desc, dst_iter_h_desc, dst_iter_c_desc,
attr);
Similarly to vanilla LSTM, we implicitly require the order of these gates to
be :math:`i`, :math:`f`, :math:`\tilde c`, and :math:`o`. For the peephole
weights, the gate order is :math:`i`, :math:`f`, :math:`o`. The following
equations give the mathematical description of these gates and output for the
forward pass:
.. math::
i_t &= \sigma(W_i \cdot h_{t,l-1} + U_i \cdot h_{t-1, l} + P_i \cdot c_{t-1} + B_i) \\
f_t &= \sigma(W_f \cdot h_{t,l-1} + U_f \cdot h_{t-1, l} + P_f \cdot c_{t-1} + B_f) \\
\\
   \tilde c_t &= \tanh(W_{\tilde c} \cdot h_{t,l-1} + U_{\tilde c} \cdot h_{t-1, l} + B_{\tilde c}) \\
   c_t &= f_t * c_{t-1} + i_t * \tilde c_t \\
   \\
   o_t &= \sigma(W_o \cdot h_{t,l-1} + U_o \cdot h_{t-1, l} + P_o \cdot c_t + B_o) \\
   h_t &= \tanh(c_t) * o_t
where :math:`P_*` are stored in ``weights_peephole``, and the other parameters
are the same as in vanilla LSTM.
.. note::
If the ``weights_peephole_desc`` passed to the primitive descriptor
   constructor is a zero memory descriptor, the primitive behaves the same as
   an LSTM primitive without peephole.
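For illustration, a default-constructed ``dnnl::memory::desc`` is a zero
memory descriptor, so the same constructor can create a peephole-less LSTM (a
sketch reusing the descriptor variables from the example above):

.. code:: cpp

   // Passing a default-constructed (zero) memory descriptor in place of
   // weights_peephole_desc disables the peephole connections.
   auto plain_lstm_pd = dnnl::lstm_forward::primitive_desc(engine, aprop,
         direction, src_layer_desc, src_iter_h_desc, src_iter_c_desc,
         weights_layer_desc, weights_iter_desc, dnnl::memory::desc(),
         bias_desc, dst_layer_desc, dst_iter_h_desc, dst_iter_c_desc, attr);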
LSTM with Projection
--------------------
A four-gate long short-term memory recurrent cell with projection
initialized with |lstm_forward::primitive_desc| or
|lstm_backward::primitive_desc| as in the following example.
.. code:: cpp
auto lstm_pd = dnnl::lstm_forward::primitive_desc(engine, aprop,
direction, src_layer_desc, src_iter_h_desc, src_iter_c_desc,
weights_layer_desc, weights_iter_desc, weights_peephole_desc,
weights_projection_desc, bias_desc, dst_layer_desc,
dst_iter_h_desc, dst_iter_c_desc, attr);
Similarly to vanilla LSTM, we implicitly require the order of the gates to be
:math:`i`, :math:`f`, :math:`\tilde c`, and :math:`o` for all tensors with a
dimension depending on the gates. The following equations give the
mathematical description of these gates and output for the forward pass (for
simplicity, LSTM without peephole is shown):
.. math::
i_t &= \sigma(W_i \cdot h_{t,l-1} + U_i \cdot h_{t-1,l} + B_i) \\
f_t &= \sigma(W_f \cdot h_{t,l-1} + U_f \cdot h_{t-1,l} + B_f) \\
& \\
\tilde{c}_t &= \tanh(W_{\tilde{c}} \cdot h_{t,l-1} + U_{\tilde{c}} \cdot h_{t-1,l} + B_{\tilde{c}}) \\
c_t &= f_t * c_{t-1} + i_t * \tilde{c}_t \\
& \\
o_t &= \sigma(W_o \cdot h_{t,l-1} + U_o \cdot h_{t-1,l} + B_o) \\
h_t &= R \cdot (\tanh(c_t) * o_t)
where :math:`R` is stored in ``weights_projection``, and the other parameters
are the same as in vanilla LSTM.
.. note::
If the ``weights_projection_desc`` passed to the primitive descriptor
   constructor is a zero memory descriptor, the primitive behaves the same as
   an LSTM primitive without projection.
.. _gru-label:
GRU
===
A three-gate gated recurrent unit cell, initialized with
|gru_forward::primitive_desc| or |gru_backward::primitive_desc| as in
the following example.
.. code:: cpp
auto gru_pd = dnnl::gru_forward::primitive_desc(engine, aprop,
direction, src_layer_desc, src_iter_desc, weights_layer_desc,
weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc,
attr);
Note that for all tensors with a dimension depending on the number of gates,
we implicitly require the order of these gates to be :math:`u`, :math:`r`,
and :math:`o`. The following equations give the mathematical definition of
these gates.
.. math::
u_t &= \sigma(W_u \cdot h_{t,l-1} + U_u \cdot h_{t-1, l} + B_u) \\
r_t &= \sigma(W_r \cdot h_{t,l-1} + U_r \cdot h_{t-1, l} + B_r) \\
   o_t &= \tanh(W_o \cdot h_{t,l-1} + U_o \cdot (r_t * h_{t-1, l}) + B_o) \\
h_t &= u_t * h_{t-1, l} + (1 - u_t) * o_t
where :math:`W_*` are stored in :math:`\weightslayer`, :math:`U_*` are stored
in :math:`\weightsiter`, and :math:`B_*` are stored in :math:`\bias`.
.. note::
If you need to replace :math:`u_t` by :math:`(1-u_t)` when computing
:math:`h_t`, you can achieve this by multiplying :math:`W_u`, :math:`U_u`
and :math:`B_u` by :math:`-1`. This is possible as :math:`u_t = \sigma(W_u
   \cdot h_{t,l-1} + U_u \cdot h_{t-1, l} + B_u)`, and :math:`1 - \sigma(a) =
   \sigma(-a)`.
.. _lbr_gru-label:
Linear-Before-Reset GRU
=======================
A three-gate gated recurrent unit cell with linear layer applied
before the reset gate, initialized with
|lbr_gru_forward::primitive_desc| or
|lbr_gru_backward::primitive_desc| as in the following example.
.. code:: cpp
auto lbr_gru_pd = dnnl::lbr_gru_forward::primitive_desc(engine,
aprop, direction, src_layer_desc, src_iter_desc,
weights_layer_desc, weights_iter_desc, bias_desc,
dst_layer_desc, dst_iter_desc, attr);
The following equations describe the mathematical behavior of the
Linear-Before-Reset GRU cell.
.. math::
u_t &= \sigma(W_u \cdot h_{t,l-1} + U_u \cdot h_{t-1, l} + B_u) \\
r_t &= \sigma(W_r \cdot h_{t,l-1} + U_r \cdot h_{t-1, l} + B_r) \\
   o_t &= \tanh(W_o \cdot h_{t,l-1} + r_t * (U_o \cdot h_{t-1, l} + B_{u'}) + B_o) \\
h_t &= u_t * h_{t-1, l} + (1 - u_t) * o_t
Note that for all tensors with a dimension depending on the number of gates,
except the bias, we implicitly require the order of these gates to be
:math:`u`, :math:`r`, and :math:`o`. For the :math:`\bias` tensor, we
implicitly require the order of the gates to be :math:`u`, :math:`r`,
:math:`o`, and :math:`u'`.
.. note::
If you need to replace :math:`u_t` by :math:`(1-u_t)` when computing
:math:`h_t`, you can achieve this by multiplying :math:`W_u`, :math:`U_u`
and :math:`B_u` by :math:`-1`. This is possible as :math:`u_t = \sigma(W_u
   \cdot h_{t,l-1} + U_u \cdot h_{t-1, l} + B_u)`, and :math:`1 - \sigma(a) =
   \sigma(-a)`.
.. _augru-label:
AUGRU
=====
A three-gate gated recurrent unit cell, initialized with
|augru_forward::primitive_desc| or |augru_backward::primitive_desc| as
in the following example.
.. code:: cpp
auto augru_pd = dnnl::augru_forward::primitive_desc(engine, aprop,
direction, src_layer_desc, src_iter_desc, attention_desc,
weights_layer_desc, weights_iter_desc, bias_desc,
dst_layer_desc, dst_iter_desc, attr);
Note that for all tensors with a dimension depending on the number of gates,
we implicitly require the order of these gates to be :math:`u`, :math:`r`,
and :math:`o`. The following equations give the mathematical definition of
these gates.
.. math::
u_t &= \sigma(W_u \cdot h_{t,l-1} + U_u \cdot h_{t-1, l} + B_u) \\
r_t &= \sigma(W_r \cdot h_{t,l-1} + U_r \cdot h_{t-1, l} + B_r) \\
o_t &= \tanh(W_o \cdot h_{t,l-1} + U_o \cdot (r_t * h_{t-1, l}) + B_o) \\
\tilde u_t &= (1 - a_t) * u_t \\
h_t &= \tilde u_t * h_{t-1, l} + (1 - \tilde u_t) * o_t
where :math:`W_*` are stored in :math:`\weightslayer`, :math:`U_*` are stored
in :math:`\weightsiter`, :math:`B_*` are stored in :math:`\bias`, and
:math:`a_t` is the attention factor provided through the attention input.
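At execution time the attention is passed as its own argument; a minimal,
illustrative sketch (the memory objects are assumed to be created beforehand,
and ``DNNL_ARG_AUGRU_ATTENTION`` is the argument index oneDNN defines for the
attention input):

.. code:: cpp

   auto augru = dnnl::augru_forward(augru_pd);
   augru.execute(stream, {
         {DNNL_ARG_SRC_LAYER, src_layer_mem},
         {DNNL_ARG_SRC_ITER, src_iter_mem},
         {DNNL_ARG_AUGRU_ATTENTION, attention_mem},
         {DNNL_ARG_WEIGHTS_LAYER, weights_layer_mem},
         {DNNL_ARG_WEIGHTS_ITER, weights_iter_mem},
         {DNNL_ARG_BIAS, bias_mem},
         {DNNL_ARG_DST_LAYER, dst_layer_mem},
         {DNNL_ARG_DST_ITER, dst_iter_mem}});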
.. _lbr_augru-label:
Linear-Before-Reset AUGRU
=========================
A three-gate gated recurrent unit cell with linear layer applied
before the reset gate, initialized with |lbr_augru_forward::primitive_desc| or
|lbr_augru_backward::primitive_desc| as in the following example.
.. code:: cpp
auto lbr_augru_pd =
dnnl::lbr_augru_forward::primitive_desc(engine, aprop,
direction, src_layer_desc, src_iter_desc, attention_desc,
weights_layer_desc, weights_iter_desc, bias_desc,
dst_layer_desc, dst_iter_desc, attr);
The following equations describe the mathematical behavior of the
Linear-Before-Reset AUGRU cell.
.. math::
u_t &= \sigma(W_u \cdot h_{t,l-1} + U_u \cdot h_{t-1, l} + B_u) \\
r_t &= \sigma(W_r \cdot h_{t,l-1} + U_r \cdot h_{t-1, l} + B_r) \\
o_t &= \tanh(W_o \cdot h_{t,l-1} + r_t *(U_o \cdot h_{t-1, l} + B_{u'}) + B_o) \\
\tilde u_t &= (1 - a_t) * u_t \\
h_t &= \tilde u_t * h_{t-1, l} + (1 - \tilde u_t) * o_t
Note that for all tensors with a dimension depending on the number of gates,
except the bias, we implicitly require the order of these gates to be
:math:`u`, :math:`r`, and :math:`o`. For the :math:`\bias` tensor, we
implicitly require the order of the gates to be :math:`u`, :math:`r`,
:math:`o`, and :math:`u'`.
*******************
Execution Arguments
*******************
When executed, the inputs and outputs should be mapped to an execution
argument index as specified by the following table.
=================================== ================================
Primitive input/output Execution argument index
=================================== ================================
:math:`\srclayer` |DNNL_ARG_SRC_LAYER|
:math:`\srciter` |DNNL_ARG_SRC_ITER|
:math:`\srciterc` |DNNL_ARG_SRC_ITER_C|
:math:`\weightslayer` |DNNL_ARG_WEIGHTS_LAYER|
:math:`\weightsiter` |DNNL_ARG_WEIGHTS_ITER|
:math:`\weightspeephole` |DNNL_ARG_WEIGHTS_PEEPHOLE|
:math:`\weightsprojection` |DNNL_ARG_WEIGHTS_PROJECTION|
:math:`\bias` |DNNL_ARG_BIAS|
:math:`\dstlayer` |DNNL_ARG_DST_LAYER|
:math:`\dstiter` |DNNL_ARG_DST_ITER|
:math:`\dstiterc` |DNNL_ARG_DST_ITER_C|
:math:`\workspace` |DNNL_ARG_WORKSPACE|
:math:`\diffsrclayer` |DNNL_ARG_DIFF_SRC_LAYER|
:math:`\diffsrciter` |DNNL_ARG_DIFF_SRC_ITER|
:math:`\diffsrciterc` |DNNL_ARG_DIFF_SRC_ITER_C|
:math:`\diffweightslayer` |DNNL_ARG_DIFF_WEIGHTS_LAYER|
:math:`\diffweightsiter` |DNNL_ARG_DIFF_WEIGHTS_ITER|
:math:`\diffweightspeephole` |DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE|
:math:`\diffweightsprojection` |DNNL_ARG_DIFF_WEIGHTS_PROJECTION|
:math:`\diffbias` |DNNL_ARG_DIFF_BIAS|
:math:`\diffdstlayer` |DNNL_ARG_DIFF_DST_LAYER|
:math:`\diffdstiter` |DNNL_ARG_DIFF_DST_ITER|
:math:`\diffdstiterc` |DNNL_ARG_DIFF_DST_ITER_C|
=================================== ================================
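As an illustrative sketch, executing the forward LSTM from the earlier
examples maps each memory object to its argument index (the ``*_mem`` memory
objects are assumed to be created beforehand; the workspace is only needed
when ``aprop`` is ``dnnl::prop_kind::forward_training``):

.. code:: cpp

   auto lstm = dnnl::lstm_forward(lstm_pd);
   lstm.execute(stream, {
         {DNNL_ARG_SRC_LAYER, src_layer_mem},
         {DNNL_ARG_SRC_ITER, src_iter_h_mem},
         {DNNL_ARG_SRC_ITER_C, src_iter_c_mem},
         {DNNL_ARG_WEIGHTS_LAYER, weights_layer_mem},
         {DNNL_ARG_WEIGHTS_ITER, weights_iter_mem},
         {DNNL_ARG_BIAS, bias_mem},
         {DNNL_ARG_DST_LAYER, dst_layer_mem},
         {DNNL_ARG_DST_ITER, dst_iter_h_mem},
         {DNNL_ARG_DST_ITER_C, dst_iter_c_mem},
         {DNNL_ARG_WORKSPACE, workspace_mem}});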
*****************
Operation Details
*****************
N/A
******************
Data Types Support
******************
The following table lists the combinations of data types that should be
supported by the RNN primitive for each input and output memory object.
.. note::
   Here we abbreviate data type names for readability. For example, |_f32|
   is abbreviated to |f32|.
+------------------------+--------------+-----------+---------------+-------------+----------+-------------+
| **Propagation** | **Cell** | **Input** | **Recurrent** | **Weights** | **Bias** | **Output** |
| | **Function** | **Data** | **Data** (1) | | | **Data** |
+------------------------+--------------+-----------+---------------+-------------+----------+-------------+
| Forward / Backward | All | |f32| | |f32| | |f32| | |f32| | |f32| |
+------------------------+--------------+-----------+---------------+-------------+----------+-------------+
| Forward / Backward (2) | All (3) | |bf16| | |bf16| | |bf16| | |f32| | |bf16| |
+------------------------+--------------+-----------+---------------+-------------+----------+-------------+
| Forward | All (3) | |f16| | |f16| | |f16| | |f16| | |f16| |
+------------------------+--------------+-----------+---------------+-------------+----------+-------------+
| Forward inference | Vanilla LSTM | |u8| | |u8| | |s8| | |f32| | |u8|, |f32| |
+------------------------+--------------+-----------+---------------+-------------+----------+-------------+
(1) With LSTM and Peephole LSTM cells, the cell state data type is always f32.
(2) In backward propagation, all ``diff_*`` tensors are in f32.
(3) Projection LSTM is not defined yet.
.. TODO: clarify if int8 lstm projection is now defined and clarify it.
Data Representation
===================
In the oneDNN programming model, the RNN primitive is one of the few that
support the placeholder memory format ``dnnl::memory::format_tag::any``
(shortened to ``any`` from now on) and can choose the format of its data and
weight memory objects based on the primitive parameters.
The following table summarizes the data layouts supported by the RNN
primitive.
+------------------+---------------+-------------------------+----------------------+-----------------------+
| **Input/Output** | **Recurrent** | **Layer and Iteration** | **Peephole Weights** | **Projection LSTM** |
| **Data** | **Data** | **Weights** | **and Bias** | **Weights** |
+------------------+---------------+-------------------------+----------------------+-----------------------+
| |any| | |any| | |any| | |ldgo| | |any|, |ldio| |
| | | | | (Forward propagation) |
+------------------+---------------+-------------------------+----------------------+-----------------------+
| |ntc|, |tnc| | |ldnc| | |ldigo|, |ldgoi| | |ldgo| | |any|, |ldio| |
| | | | | (Forward propagation) |
+------------------+---------------+-------------------------+----------------------+-----------------------+
While an RNN primitive can be created with memory formats specified
explicitly, the performance is likely to be sub-optimal. When using |any|, it
is necessary to first create an RNN primitive descriptor and then query it
for the actual data and weight memory object formats, as in the sketch below.
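A minimal sketch of this flow for the layer weights of a forward LSTM; the
dimension variables (``L``, ``D``, ``SLC``, ``G``, ``DHC``) and
``user_weights_layer_mem`` are assumptions defined by the application:

.. code:: cpp

   // Let the primitive choose the weights layout.
   auto weights_layer_md = dnnl::memory::desc({L, D, SLC, G, DHC},
         dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);

   // ... create the other descriptors and lstm_pd as in the examples above ...

   // Query the layout the primitive descriptor actually selected.
   auto opt_weights_layer_md = lstm_pd.weights_layer_desc();
   if (opt_weights_layer_md != user_weights_layer_mem.get_desc()) {
       // Reorder the user data into the layout chosen by the primitive.
       auto opt_weights_layer_mem
               = dnnl::memory(opt_weights_layer_md, engine);
       dnnl::reorder(user_weights_layer_mem, opt_weights_layer_mem)
               .execute(stream, user_weights_layer_mem, opt_weights_layer_mem);
   }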
.. note::
   The RNN primitive should support padded tensors and views. As a
   consequence, even if two memory descriptors share the same data layout,
   they might still be different.
Post-ops and Attributes
=======================
Currently, post-ops and attributes are only used by the int8 variant of LSTM,
as in the sketch below.
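A minimal sketch of setting the int8 quantization attributes for LSTM
inference; the scale and shift values here are illustrative assumptions:

.. code:: cpp

   dnnl::primitive_attr attr;
   // Scale and shift for quantizing the input and recurrent data:
   // data_int8 = scale * data_f32 + shift.
   attr.set_rnn_data_qparams(64.f, 128.f);
   // A single common scale for all weights (mask = 0).
   attr.set_rnn_weights_qparams(0, {127.f});

The resulting ``attr`` is then passed to the LSTM primitive descriptor
constructor as in the earlier examples.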
.. TODO quantization
***
API
***
.. doxygenenum:: dnnl::rnn_flags
:project: oneDNN
.. doxygenenum:: dnnl::rnn_direction
:project: oneDNN
.. doxygenstruct:: dnnl::vanilla_rnn_forward
:project: oneDNN
:members:
.. doxygenstruct:: dnnl::vanilla_rnn_backward
:project: oneDNN
:members:
.. doxygenstruct:: dnnl::lstm_forward
:project: oneDNN
:members:
.. doxygenstruct:: dnnl::lstm_backward
:project: oneDNN
:members:
.. doxygenstruct:: dnnl::gru_forward
:project: oneDNN
:members:
.. doxygenstruct:: dnnl::gru_backward
:project: oneDNN
:members:
.. doxygenstruct:: dnnl::lbr_gru_forward
:project: oneDNN
:members:
.. doxygenstruct:: dnnl::lbr_gru_backward
   :project: oneDNN
   :members:

.. doxygenstruct:: dnnl::augru_forward
   :project: oneDNN
   :members:

.. doxygenstruct:: dnnl::augru_backward
   :project: oneDNN
   :members:

.. doxygenstruct:: dnnl::lbr_augru_forward
   :project: oneDNN
   :members:

.. doxygenstruct:: dnnl::lbr_augru_backward
   :project: oneDNN
   :members:
.. vim: ts=3 sw=3 et spell spelllang=en