Batch Normalization#

The batch normalization primitive performs a forward or backward batch normalization operation on tensors with two or more dimensions. Variable names follow the standard Conventions.

The batch normalization operation is defined by the following formulas. The formulas are shown only for 2D spatial data; they generalize straightforwardly to cases with more or fewer spatial dimensions.

The different flavors of the primitive are controlled by the flags parameter that is passed to the primitive descriptor initialization function like dnnl::batch_normalization_forward::primitive_desc. Multiple flags can be combined using the bitwise OR operator (|).

Forward#

\[\dst(n, c, h, w) = \gamma(c) \cdot \frac{\src(n, c, h, w) - \mu(c)} {\sqrt{\sigma^2(c) + \varepsilon}} + \beta(c),\]

where

  • \(\gamma(c)\) and \(\beta(c)\) are optional scale and shift for a channel (controlled using the use_scale and use_shift flags),

  • \(\mu(c)\) and \(\sigma^2(c)\) are mean and variance for a channel (controlled using the use_global_stats flag), and

  • \(\varepsilon\) is a constant to improve numerical stability.
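As an illustration of the forward formula, the following sketch applies it to a dense f32 tensor in plain nchw layout. The function name bnorm_forward_ref and its signature are hypothetical and not part of the oneDNN API; the code assumes that mean and variance are already available and that both \(\gamma\) and \(\beta\) are in use.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not a oneDNN function): applies
//   dst = gamma(c) * (src - mu(c)) / sqrt(sigma^2(c) + eps) + beta(c)
// to a dense NCHW tensor stored in plain nchw order.
std::vector<float> bnorm_forward_ref(const std::vector<float> &src,
        std::size_t N, std::size_t C, std::size_t H, std::size_t W,
        const std::vector<float> &mean, const std::vector<float> &variance,
        const std::vector<float> &gamma, const std::vector<float> &beta,
        float epsilon) {
    assert(src.size() == N * C * H * W);
    assert(mean.size() == C && variance.size() == C);
    assert(gamma.size() == C && beta.size() == C);

    std::vector<float> dst(src.size());
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t c = 0; c < C; ++c) {
            // 1 / sqrt(sigma^2(c) + eps) is constant within a channel.
            const float inv_std = 1.0f / std::sqrt(variance[c] + epsilon);
            for (std::size_t hw = 0; hw < H * W; ++hw) {
                const std::size_t off = (n * C + c) * H * W + hw;
                dst[off] = gamma[c] * (src[off] - mean[c]) * inv_std + beta[c];
            }
        }
    }
    return dst;
}
```

When use_scale or use_shift is not set, passing \(\gamma(c) = 1\) or \(\beta(c) = 0\) to this sketch reproduces the unscaled or unshifted result.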

Mean and variance are computed at runtime or provided by a user. When mean and variance are computed at runtime, the following formulas are used:

  • \(\mu(c) = \frac{1}{NHW} \sum\limits_{nhw} \src(n, c, h, w)\),

  • \(\sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} (\src(n, c, h, w) - \mu(c))^2\).
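The two statistics formulas can be sketched in plain C++ as follows. This is a minimal reference under stated assumptions: the tensor is dense f32 in plain nchw order, and the names ChannelStats and bnorm_stats_ref are hypothetical, not oneDNN identifiers. Sums are accumulated in double only to keep the sketch numerically tame.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for per-channel statistics (not a oneDNN type).
struct ChannelStats {
    std::vector<float> mean;
    std::vector<float> variance;
};

// Computes the population mean and variance of every channel of a dense
// NCHW tensor stored in plain nchw order. The divisor is N*H*W, not N*H*W - 1.
ChannelStats bnorm_stats_ref(const std::vector<float> &src,
        std::size_t N, std::size_t C, std::size_t H, std::size_t W) {
    assert(src.size() == N * C * H * W);
    const std::size_t count = N * H * W;
    assert(count > 0);

    ChannelStats stats{std::vector<float>(C, 0.0f), std::vector<float>(C, 0.0f)};
    for (std::size_t c = 0; c < C; ++c) {
        // First pass: mu(c) = sum(src) / (N*H*W).
        double sum = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t hw = 0; hw < H * W; ++hw)
                sum += src[(n * C + c) * H * W + hw];
        const double mu = sum / static_cast<double>(count);

        // Second pass: sigma^2(c) = sum((src - mu)^2) / (N*H*W).
        double sq_sum = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t hw = 0; hw < H * W; ++hw) {
                const double d = src[(n * C + c) * H * W + hw] - mu;
                sq_sum += d * d;
            }
        stats.mean[c] = static_cast<float>(mu);
        stats.variance[c] = static_cast<float>(sq_sum / static_cast<double>(count));
    }
    return stats;
}
```

For a tensor with N = 2, C = 2, H = W = 1 holding {1, 10, 3, 10}, channel 0 sees the values 1 and 3 (mean 2, variance 1) and channel 1 sees 10 twice (mean 10, variance 0).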

The \(\gamma(c)\) and \(\beta(c)\) tensors are considered learnable.

In training mode, the primitive also optionally supports fusion with a ReLU activation with zero negative slope, applied to the result (see the fuse_norm_relu flag).

Note

The batch normalization primitive computes population mean and variance and not the sample or unbiased versions that are typically used to compute running mean and variance. Using the mean and variance computed by the batch normalization primitive, the running mean and variance \(\hat\mu_i\) and \(\hat\sigma^2_i\), where \(i\) is the iteration number and \(\alpha\) is a user-chosen smoothing factor, can be computed as:

\[\begin{split}\hat\mu_{i+1} = \alpha \cdot \hat\mu_i + (1 - \alpha) \cdot \mu, \\ \hat\sigma^2_{i+1} = \alpha \cdot \hat\sigma^2_i + (1 - \alpha) \cdot \sigma^2.\end{split}\]
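The update rule above can be sketched in plain C++. The names RunningStats, update_running_stats, and unbiased_variance are hypothetical helpers, not oneDNN functions. The second helper is an assumption about what a caller might want: it applies Bessel's correction to turn the population variance returned by the primitive into the unbiased estimate before feeding it into the running average.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical per-channel running statistics (not a oneDNN type).
struct RunningStats {
    float mean;
    float variance;
};

// One step of the exponential moving average from the formula above:
//   running <- alpha * running + (1 - alpha) * batch
RunningStats update_running_stats(RunningStats running, float batch_mean,
        float batch_variance, float alpha) {
    assert(alpha >= 0.0f && alpha <= 1.0f);
    running.mean = alpha * running.mean + (1.0f - alpha) * batch_mean;
    running.variance
            = alpha * running.variance + (1.0f - alpha) * batch_variance;
    return running;
}

// Optional Bessel correction: converts a population variance computed over
// `count` = N*H*W elements into the unbiased sample variance.
float unbiased_variance(float population_variance, std::size_t count) {
    assert(count > 1);
    return population_variance * static_cast<float>(count)
            / static_cast<float>(count - 1);
}
```

Whether the corrected or uncorrected variance goes into the running average is a framework-level decision; the primitive itself only provides the population value.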

Difference Between Forward Training and Forward Inference#

  • If mean and variance are computed at runtime (i.e., use_global_stats is not set), they become outputs for the propagation kind forward_training (because they would be required during the backward propagation) and are not exposed for the propagation kind forward_inference.

  • If batch normalization is created with ReLU fusion (i.e., fuse_norm_relu is set), for the propagation kind forward_training the primitive produces a workspace memory as one extra output. This memory is required to compute the backward propagation. When the primitive is executed with propagation kind forward_inference, the workspace is not produced. In that case, the behavior is the same as creating a batch normalization primitive with ReLU as a post-op (see Post-ops and Attributes).

Backward#

The backward propagation computes \(\diffsrc(n, c, h, w)\), \(\diffgamma(c)^*\), and \(\diffbeta(c)^*\) based on \(\diffdst(n, c, h, w)\), \(\src(n, c, h, w)\), \(\mu(c)\), \(\sigma^2(c)\), \(\gamma(c) ^*\), and \(\beta(c) ^*\).

The tensors marked with an asterisk are used only when the primitive is configured to use \(\gamma(c)\) and \(\beta(c)\) (i.e., use_scale and use_shift are set).
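This page does not spell out the backward formulas. The sketch below uses the standard batch normalization gradients for a single channel and is an assumption about the math, not a description of the library's implementation. With \(\hat{x} = (\src - \mu) / \sqrt{\sigma^2 + \varepsilon}\) and \(M = NHW\), it computes \(\diffbeta = \sum \diffdst\), \(\diffgamma = \sum \diffdst \cdot \hat{x}\), and \(\diffsrc = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}} (\diffdst - \diffbeta / M - \hat{x} \cdot \diffgamma / M)\). When statistics are user-provided (use_global_stats), they are constants and the two correction terms drop out. The names BnormGrads and bnorm_backward_channel_ref are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type (not a oneDNN type).
struct BnormGrads {
    std::vector<float> diff_src;
    float diff_gamma;
    float diff_beta;
};

// Standard batch normalization gradients for ONE channel. `src` and
// `diff_dst` hold all N*H*W values that belong to that channel.
BnormGrads bnorm_backward_channel_ref(const std::vector<float> &src,
        const std::vector<float> &diff_dst, float mean, float variance,
        float gamma, float epsilon, bool use_global_stats) {
    assert(src.size() == diff_dst.size() && !src.empty());
    const float inv_std = 1.0f / std::sqrt(variance + epsilon);
    const float count = static_cast<float>(src.size());

    BnormGrads g{std::vector<float>(src.size()), 0.0f, 0.0f};
    // Reductions over the channel: diff_gamma and diff_beta.
    for (std::size_t i = 0; i < src.size(); ++i) {
        const float x_hat = (src[i] - mean) * inv_std;
        g.diff_gamma += diff_dst[i] * x_hat;
        g.diff_beta += diff_dst[i];
    }
    for (std::size_t i = 0; i < src.size(); ++i) {
        float d = diff_dst[i];
        if (!use_global_stats) {
            // mean and variance depend on src, which adds two terms.
            const float x_hat = (src[i] - mean) * inv_std;
            d -= g.diff_beta / count + x_hat * g.diff_gamma / count;
        }
        g.diff_src[i] = gamma * inv_std * d;
    }
    return g;
}
```

A quick sanity check: with two elements per channel the normalized values are always -1 and +1 regardless of the inputs, so the gradient with respect to \(\src\) is zero when statistics are computed at runtime.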

Execution Arguments#

Depending on the flags and propagation kind, the batch normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.

| Flags | forward_inference | forward_training | backward | backward_data |
| --- | --- | --- | --- | --- |
| none | In: \(\src\); Out: \(\dst\) | In: \(\src\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\); Out: \(\diffsrc\) | Same as for backward |
| use_global_stats | In: \(\src\), \(\mu\), \(\sigma^2\); Out: \(\dst\) | In: \(\src\), \(\mu\), \(\sigma^2\); Out: \(\dst\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\); Out: \(\diffsrc\) | Same as for backward |
| use_scale | In: \(\src\), \(\gamma\); Out: \(\dst\) | In: \(\src\), \(\gamma\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\); Out: \(\diffsrc\), \(\diffgamma\) | Not supported |
| use_shift | In: \(\src\), \(\beta\); Out: \(\dst\) | In: \(\src\), \(\beta\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\beta\); Out: \(\diffsrc\), \(\diffbeta\) | Not supported |
| use_scale \| use_shift | In: \(\src\), \(\gamma\), \(\beta\); Out: \(\dst\) | In: \(\src\), \(\gamma\), \(\beta\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Not supported |
| use_global_stats \| use_scale \| use_shift | In: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\dst\) | In: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\dst\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Not supported |
| flags \| fuse_norm_relu | In: same as with flags; Out: same as with flags | In: same as with flags; Out: same as with flags, workspace | In: same as with flags, workspace; Out: same as with flags | Same as for backward if flags do not contain use_scale or use_shift; not supported otherwise |

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

| Primitive input/output | Execution argument index |
| --- | --- |
| \(\src\) | DNNL_ARG_SRC |
| \(\gamma\) | DNNL_ARG_SCALE |
| \(\beta\) | DNNL_ARG_SHIFT |
| mean (\(\mu\)) | DNNL_ARG_MEAN |
| variance (\(\sigma^2\)) | DNNL_ARG_VARIANCE |
| \(\dst\) | DNNL_ARG_DST |
| workspace | DNNL_ARG_WORKSPACE |
| \(\diffdst\) | DNNL_ARG_DIFF_DST |
| \(\diffsrc\) | DNNL_ARG_DIFF_SRC |
| \(\diffgamma\) | DNNL_ARG_DIFF_SCALE |
| \(\diffbeta\) | DNNL_ARG_DIFF_SHIFT |
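The forward columns of the flags summary can be restated as a small decision procedure. The sketch below is only an illustration: the kUse... constants are hypothetical stand-ins whose values do not match dnnl::normalization_flags, the argument indices are carried as strings rather than the real DNNL_ARG_* macros, and forward_args is not a oneDNN function.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for dnnl::normalization_flags. The numeric values
// are illustrative only and do NOT match the library's enum values.
constexpr unsigned kUseGlobalStats = 1u << 0;
constexpr unsigned kUseScale = 1u << 1;
constexpr unsigned kUseShift = 1u << 2;
constexpr unsigned kFuseNormRelu = 1u << 3;

struct ArgList {
    std::vector<std::string> in;
    std::vector<std::string> out;
};

// Lists the execution arguments required by forward propagation for a given
// combination of flags. `training` selects forward_training vs.
// forward_inference.
ArgList forward_args(unsigned flags, bool training) {
    ArgList args;
    args.in.push_back("DNNL_ARG_SRC");
    args.out.push_back("DNNL_ARG_DST");

    if (flags & kUseGlobalStats) {
        // User-provided statistics are inputs for both propagation kinds.
        args.in.push_back("DNNL_ARG_MEAN");
        args.in.push_back("DNNL_ARG_VARIANCE");
    } else if (training) {
        // Statistics computed at runtime are exposed only in training.
        args.out.push_back("DNNL_ARG_MEAN");
        args.out.push_back("DNNL_ARG_VARIANCE");
    }
    if (flags & kUseScale) args.in.push_back("DNNL_ARG_SCALE");
    if (flags & kUseShift) args.in.push_back("DNNL_ARG_SHIFT");
    // The workspace is produced only by fused ReLU in training.
    if ((flags & kFuseNormRelu) && training)
        args.out.push_back("DNNL_ARG_WORKSPACE");
    return args;
}
```

For example, forward_args(kUseScale | kUseShift, true) yields the inputs \(\src\), \(\gamma\), \(\beta\) and the outputs \(\dst\), \(\mu\), \(\sigma^2\), matching the use_scale | use_shift row for forward_training.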

Operation Details#

  1. For forward propagation, the mean and variance might be either computed at runtime (in which case they are outputs of the primitive) or provided by a user (in which case they are inputs). In the latter case, a user must set the use_global_stats flag. For the backward propagation, the mean and variance are always input parameters.

  2. The memory format and data type for src and dst are assumed to be the same, and in the API they are typically referred to as data (e.g., see data_desc in dnnl::batch_normalization_forward::primitive_desc). The same is true for diff_src and diff_dst. The corresponding memory descriptors are referred to as diff_data_desc.

  3. Both forward and backward propagation support in-place operations, meaning that \(\src\) can be used as input and output for forward propagation, and \(\diffdst\) can be used as input and output for backward propagation. In case of an in-place operation, the original data will be overwritten. Note, however, that backward propagation requires original \(\src\), hence the corresponding forward propagation should not be performed in-place.

  4. As mentioned above, the batch normalization primitive can be fused with ReLU activation even in the training mode. In this case, on the forward propagation the primitive has one additional output, workspace, that should be passed during the backward propagation.
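To make the role of the workspace concrete, the following conceptual sketch records which normalized values survive the ReLU on the forward pass and reuses that record on the backward pass. It assumes a one-byte-per-element mask purely for illustration; the real workspace layout is not described on this page and should be treated as opaque, with its descriptor obtained from workspace_desc(). Both function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Forward: applies ReLU (zero negative slope) to already normalized values
// and records in `workspace` which elements were kept.
std::vector<float> relu_forward_with_mask(const std::vector<float> &normalized,
        std::vector<unsigned char> &workspace) {
    workspace.assign(normalized.size(), 0);
    std::vector<float> dst(normalized.size(), 0.0f);
    for (std::size_t i = 0; i < normalized.size(); ++i) {
        if (normalized[i] > 0.0f) {
            dst[i] = normalized[i];
            workspace[i] = 1;
        }
    }
    return dst;
}

// Backward: zeroes the incoming gradient wherever the forward pass clipped
// the value. The result is what the normalization gradient would consume.
std::vector<float> relu_backward_with_mask(const std::vector<float> &diff_dst,
        const std::vector<unsigned char> &workspace) {
    assert(diff_dst.size() == workspace.size());
    std::vector<float> masked(diff_dst.size(), 0.0f);
    for (std::size_t i = 0; i < diff_dst.size(); ++i)
        if (workspace[i]) masked[i] = diff_dst[i];
    return masked;
}
```

The mask is needed because the backward pass only receives \(\src\), not the normalized pre-activation values, so it cannot cheaply tell which outputs were clipped without either recomputing them or consulting the workspace.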

Data Types Support#

The operation supports the following combinations of data types.

Note

Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.

| Propagation | Source / Destination | Mean / Variance / Scale / Shift |
| --- | --- | --- |
| forward / backward | f32, bf16 | f32 |
| forward | f16 | f32 |
| forward | s8 | f32 |

Data Representation#

Source, Destination, and Their Gradients#

Like other CNN primitives, the batch normalization primitive expects data to be an \(N \times C \times SP_n \times \cdots \times SP_0\) tensor.

The batch normalization primitive is optimized for the following memory formats:

| Spatial | Logical tensor | Implementations optimized for memory formats |
| --- | --- | --- |
| 0D | NC | nc (ab) |
| 1D | NCW | ncw (abc), nwc (acb), optimized |
| 2D | NCHW | nchw (abcd), nhwc (acdb), optimized |
| 3D | NCDHW | ncdhw (abcde), ndhwc (acdeb), optimized |

Here optimized means the format chosen by the preceding compute-intensive primitive.
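The difference between the two plain 2D formats comes down to how a logical index \((n, c, h, w)\) maps to a linear offset. The helpers below are a hypothetical sketch for dense tensors without padding or blocking; they are not oneDNN functions.

```cpp
#include <cassert>
#include <cstddef>

// nchw (abcd): w varies fastest, then h, then c, then n. All H*W values of
// one channel within an image are contiguous.
std::size_t offset_nchw(std::size_t n, std::size_t c, std::size_t h,
        std::size_t w, std::size_t C, std::size_t H, std::size_t W) {
    assert(c < C && h < H && w < W);
    return ((n * C + c) * H + h) * W + w;
}

// nhwc (acdb): c varies fastest, then w, then h, then n. Values of one
// channel are C elements apart.
std::size_t offset_nhwc(std::size_t n, std::size_t c, std::size_t h,
        std::size_t w, std::size_t C, std::size_t H, std::size_t W) {
    assert(c < C && h < H && w < W);
    return ((n * H + h) * W + w) * C + c;
}
```

Because the statistics are reduced over \(n\), \(h\), and \(w\) for a fixed \(c\), the layout determines whether that reduction walks contiguous memory (nchw) or strided memory (nhwc).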

Statistics Tensors#

The mean (\(\mu\)) and variance (\(\sigma^2\)) are separate 1D tensors of size \(C\).

The format of the corresponding memory object must be x (a).

If used, the scale (\(\gamma\)) and shift (\(\beta\)) are separate 1D tensors of size \(C\), passed as DNNL_ARG_SCALE and DNNL_ARG_SHIFT respectively.

The format of the corresponding memory objects must be x (a).

Post-ops and Attributes#

| Propagation | Type | Operation | Description |
| --- | --- | --- | --- |
| forward | post-op | eltwise | Applies an eltwise operation to the output. |

Note

Using ReLU as a post-op does not produce additional output in the workspace that is required to compute backward propagation correctly. Hence, one should use the fuse_norm_relu flag for training.

API#

struct dnnl::batch_normalization_forward : public dnnl::primitive#

Batch normalization forward propagation primitive.

Public Functions

batch_normalization_forward()#

Default constructor. Produces an empty object.

batch_normalization_forward(const primitive_desc &pd)#

Constructs a batch normalization forward propagation primitive.

Parameters

pd – Primitive descriptor for a batch normalization forward propagation primitive.

struct primitive_desc : public dnnl::primitive_desc#

Primitive descriptor for a batch normalization forward propagation primitive.

Public Functions

primitive_desc() = default#

Default constructor. Produces an empty object.

primitive_desc(const engine &aengine, prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &dst_desc, float epsilon, normalization_flags flags, const primitive_attr &attr = default_attr(), bool allow_empty = false)#

Constructs a primitive descriptor for a batch normalization forward propagation primitive.

Note

In-place operation is supported: the dst can refer to the same memory as the src.

Parameters
  • aengine – Engine to use.

  • aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • src_desc – Source memory descriptor.

  • dst_desc – Destination memory descriptor.

  • epsilon – Batch normalization epsilon parameter.

  • flags – Batch normalization flags (dnnl::normalization_flags).

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const#

Returns a source memory descriptor.

Returns

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc dst_desc() const#

Returns a destination memory descriptor.

Returns

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.

memory::desc weights_desc() const#

Returns a weights memory descriptor.

Returns

Weights memory descriptor, or a zero memory descriptor if the primitive does not have a weights parameter.

memory::desc workspace_desc() const#

Returns the workspace memory descriptor.

Returns

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.

memory::desc mean_desc() const#

Returns memory descriptor for mean.

Returns

Memory descriptor for mean.

memory::desc variance_desc() const#

Returns memory descriptor for variance.

Returns

Memory descriptor for variance.

dnnl::prop_kind get_prop_kind() const#

Returns a propagation kind.

Returns

A propagation kind, or dnnl::prop_kind::undef if the primitive does not have a propagation parameter.

float get_epsilon() const#

Returns an epsilon.

Returns

An epsilon, or zero if the primitive does not have an epsilon parameter.

normalization_flags get_flags() const#

Returns normalization flags.

Returns

Normalization flags.

struct dnnl::batch_normalization_backward : public dnnl::primitive#

Batch normalization backward propagation primitive.

Public Functions

batch_normalization_backward()#

Default constructor. Produces an empty object.

batch_normalization_backward(const primitive_desc &pd)#

Constructs a batch normalization backward propagation primitive.

Parameters

pd – Primitive descriptor for a batch normalization backward propagation primitive.

struct primitive_desc : public dnnl::primitive_desc#

Primitive descriptor for a batch normalization backward propagation primitive.

Public Functions

primitive_desc() = default#

Default constructor. Produces an empty object.

primitive_desc(const engine &aengine, prop_kind aprop_kind, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::desc &src_desc, float epsilon, normalization_flags flags, const batch_normalization_forward::primitive_desc &hint_fwd_pd, const primitive_attr &attr = default_attr(), bool allow_empty = false)#

Constructs a primitive descriptor for a batch normalization backward propagation primitive.

Parameters
  • aengine – Engine to use.

  • aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::backward_data and dnnl::prop_kind::backward (diffs for all parameters are computed in this case).

  • diff_src_desc – Diff source memory descriptor.

  • diff_dst_desc – Diff destination memory descriptor.

  • src_desc – Source memory descriptor.

  • epsilon – Batch normalization epsilon parameter.

  • flags – Batch normalization flags (dnnl::normalization_flags).

  • hint_fwd_pd – Primitive descriptor for a batch normalization forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const#

Returns a source memory descriptor.

Returns

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc weights_desc() const#

Returns a weights memory descriptor.

Returns

Weights memory descriptor, or a zero memory descriptor if the primitive does not have a weights parameter.

memory::desc dst_desc() const#

Returns a destination memory descriptor.

Returns

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.

memory::desc diff_src_desc() const#

Returns a diff source memory descriptor.

Returns

Diff source memory descriptor, or a zero memory descriptor if the primitive does not have a diff source parameter.

memory::desc diff_dst_desc() const#

Returns a diff destination memory descriptor.

Returns

Diff destination memory descriptor, or a zero memory descriptor if the primitive does not have a diff destination parameter.

memory::desc diff_weights_desc()#

Returns a diff weights memory descriptor.

Returns

Diff weights memory descriptor, or a zero memory descriptor if the primitive does not have a diff weights parameter.

memory::desc mean_desc() const#

Returns memory descriptor for mean.

Returns

Memory descriptor for mean.

memory::desc variance_desc() const#

Returns memory descriptor for variance.

Returns

Memory descriptor for variance.

memory::desc workspace_desc() const#

Returns the workspace memory descriptor.

Returns

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.

dnnl::prop_kind get_prop_kind() const#

Returns a propagation kind.

Returns

A propagation kind, or dnnl::prop_kind::undef if the primitive does not have a propagation parameter.

float get_epsilon() const#

Returns an epsilon.

Returns

An epsilon, or zero if the primitive does not have an epsilon parameter.

normalization_flags get_flags() const#

Returns normalization flags.

Returns

Normalization flags.