Layer normalization#

The layer normalization primitive performs a forward or backward layer normalization operation on a 2-5D data tensor.

The layer normalization operation performs normalization over the last logical axis of the data tensor and is defined by the following formulas. We show formulas only for 3D data, which are straightforward to generalize to cases of higher dimensions. Variable names follow the standard Conventions.

Forward#

\[\dst(t, n, c) = \gamma(c) \cdot \frac{\src(t, n, c) - \mu(t, n)} {\sqrt{\sigma^2(t, n) + \varepsilon}} + \beta(c),\]

where

  • \(\gamma(c), \beta(c)\) are optional scale and shift for a channel (see the use_scale and use_shift flags),

  • \(\mu(t, n), \sigma^2(t, n)\) are mean and variance (see the use_global_stats flag), and

  • \(\varepsilon\) is a constant to improve numerical stability.

Mean and variance are computed at runtime or provided by a user. When mean and variance are computed at runtime, the following formulas are used:

  • \(\mu(t, n) = \frac{1}{C} \sum\limits_{c} \src(t, n, c)\),

  • \(\sigma^2(t, n) = \frac{1}{C} \sum\limits_{c} (\src(t, n, c) - \mu(t, n))^2\).

The \(\gamma(c)\) and \(\beta(c)\) tensors are considered learnable.
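The forward formulas above can be sketched as a naive reference routine in plain C++. This is only an illustration of the math, not the library's implementation: it assumes a dense row-major tnc tensor stored in a std::vector<float>, and the names LayerNormResult and layer_norm_forward_ref are hypothetical. An empty gamma or beta vector stands for "scale or shift not used".

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result container for this sketch; not a oneDNN type.
struct LayerNormResult {
    std::vector<float> dst;      // same shape as src: T x N x C
    std::vector<float> mean;     // T x N
    std::vector<float> variance; // T x N
};

// Naive forward layer normalization over the last axis (C) of a dense
// row-major tnc tensor. Statistics are always computed from src here.
LayerNormResult layer_norm_forward_ref(const std::vector<float> &src,
        std::size_t T, std::size_t N, std::size_t C, float epsilon,
        const std::vector<float> &gamma = {},
        const std::vector<float> &beta = {}) {
    assert(C > 0);
    assert(src.size() == T * N * C);
    assert(gamma.empty() || gamma.size() == C);
    assert(beta.empty() || beta.size() == C);

    LayerNormResult r;
    r.dst.resize(src.size());
    r.mean.resize(T * N);
    r.variance.resize(T * N);

    for (std::size_t row = 0; row < T * N; ++row) {
        const float *x = src.data() + row * C;

        // mu(t, n) = 1/C * sum_c src(t, n, c)
        float mu = 0.f;
        for (std::size_t c = 0; c < C; ++c) mu += x[c];
        mu /= static_cast<float>(C);

        // sigma^2(t, n) = 1/C * sum_c (src(t, n, c) - mu(t, n))^2
        float var = 0.f;
        for (std::size_t c = 0; c < C; ++c) var += (x[c] - mu) * (x[c] - mu);
        var /= static_cast<float>(C);

        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for (std::size_t c = 0; c < C; ++c) {
            const float g = gamma.empty() ? 1.f : gamma[c];
            const float b = beta.empty() ? 0.f : beta[c];
            r.dst[row * C + c] = g * (x[c] - mu) * inv_std + b;
        }
        r.mean[row] = mu;
        r.variance[row] = var;
    }
    return r;
}
```

For example, a single row src = {1, 2, 3, 4} with C = 4 has mean 2.5 and variance 1.25, so each output element is (src - 2.5) / sqrt(1.25 + epsilon) before the optional scale and shift are applied.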

Difference Between Forward Training and Forward Inference#

If mean and variance are computed at runtime (i.e., use_global_stats is not set), they become outputs for the propagation kind forward_training (because they are required during backward propagation). The data layout for mean and variance must be specified during initialization of the layer normalization descriptor by passing the memory descriptor for statistics (e.g., by passing stat_desc in dnnl::layer_normalization_forward::primitive_desc). Mean and variance are not exposed for the propagation kind forward_inference.

Backward#

The backward propagation computes \(\diffsrc(t, n, c)\), \(\diffgamma(c)^*\), and \(\diffbeta(c)^*\) based on \(\diffdst(t, n, c)\), \(\src(t, n, c)\), \(\mu(t, n)\), \(\sigma^2(t, n)\), \(\gamma(c)^*\), and \(\beta(c)^*\).

The tensors marked with an asterisk are used only when the primitive is configured to use \(\gamma(c)\) and \(\beta(c)\) (i.e., use_scale and use_shift are set).
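The text does not spell out the backward formulas, so the following plain C++ sketch uses the standard layer normalization gradient, obtained by differentiating the forward formula with the statistics computed from \(\src\). It is an assumption-laden illustration, not the library's code: the names LayerNormGrads and layer_norm_backward_ref are hypothetical, rows stands for the product of all non-feature dimensions, and the case where statistics are supplied via use_global_stats may be handled differently by the primitive.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result container for this sketch; not a oneDNN type.
struct LayerNormGrads {
    std::vector<float> diff_src;   // rows x C
    std::vector<float> diff_gamma; // C
    std::vector<float> diff_beta;  // C
};

// Naive backward pass for layer normalization over the last axis.
// With x_hat = (src - mean) / sqrt(variance + epsilon) and g = gamma * diff_dst:
//   diff_gamma(c) = sum over rows of diff_dst * x_hat
//   diff_beta(c)  = sum over rows of diff_dst
//   diff_src      = (g - mean_c(g) - x_hat * mean_c(g * x_hat)) / sqrt(variance + epsilon)
LayerNormGrads layer_norm_backward_ref(const std::vector<float> &diff_dst,
        const std::vector<float> &src, const std::vector<float> &mean,
        const std::vector<float> &variance, const std::vector<float> &gamma,
        std::size_t rows, std::size_t C, float epsilon) {
    assert(C > 0);
    assert(src.size() == rows * C && diff_dst.size() == rows * C);
    assert(mean.size() == rows && variance.size() == rows);
    assert(gamma.size() == C);

    LayerNormGrads out;
    out.diff_src.assign(rows * C, 0.f);
    out.diff_gamma.assign(C, 0.f);
    out.diff_beta.assign(C, 0.f);

    for (std::size_t r = 0; r < rows; ++r) {
        const float inv_std = 1.f / std::sqrt(variance[r] + epsilon);

        // First pass: per-row reductions and per-channel accumulations.
        float sum_g = 0.f, sum_gx = 0.f;
        for (std::size_t c = 0; c < C; ++c) {
            const std::size_t i = r * C + c;
            const float x_hat = (src[i] - mean[r]) * inv_std;
            const float g = gamma[c] * diff_dst[i];
            sum_g += g;
            sum_gx += g * x_hat;
            out.diff_gamma[c] += diff_dst[i] * x_hat;
            out.diff_beta[c] += diff_dst[i];
        }
        const float mean_g = sum_g / static_cast<float>(C);
        const float mean_gx = sum_gx / static_cast<float>(C);

        // Second pass: gradient with respect to the source.
        for (std::size_t c = 0; c < C; ++c) {
            const std::size_t i = r * C + c;
            const float x_hat = (src[i] - mean[r]) * inv_std;
            const float g = gamma[c] * diff_dst[i];
            out.diff_src[i] = inv_std * (g - mean_g - x_hat * mean_gx);
        }
    }
    return out;
}
```

One sanity check for this form: because the forward output does not change when a constant is added to every element of a row, the elements of diff_src in each row sum to approximately zero.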

Execution Arguments#

Depending on the flags and propagation kind, the layer normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.

| Flags | forward_inference | forward_training | backward | backward_data |
| --- | --- | --- | --- | --- |
| none | In: \(\src\); Out: \(\dst\) | In: \(\src\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\); Out: \(\diffsrc\) | Same as for backward |
| use_global_stats | In: \(\src\), \(\mu\), \(\sigma^2\); Out: \(\dst\) | In: \(\src\), \(\mu\), \(\sigma^2\); Out: \(\dst\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\); Out: \(\diffsrc\) | Same as for backward |
| use_scale | In: \(\src\), \(\gamma\); Out: \(\dst\) | In: \(\src\), \(\gamma\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\); Out: \(\diffsrc\), \(\diffgamma\) | Not supported |
| use_shift | In: \(\src\), \(\beta\); Out: \(\dst\) | In: \(\src\), \(\beta\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\beta\); Out: \(\diffsrc\), \(\diffbeta\) | Not supported |
| use_scale \| use_shift | In: \(\src\), \(\gamma\), \(\beta\); Out: \(\dst\) | In: \(\src\), \(\gamma\), \(\beta\); Out: \(\dst\), \(\mu\), \(\sigma^2\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Not supported |
| use_global_stats \| use_scale \| use_shift | In: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\dst\) | In: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\dst\) | In: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Out: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Not supported |
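The summary above can also be read as a small decision procedure. The sketch below encodes it in plain C++ purely for illustration; Prop, the kUse* constants, RequiredArgs, and required_args are hypothetical stand-ins and are not part of the oneDNN API.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for the propagation kinds and flags in the summary.
enum class Prop { forward_inference, forward_training, backward, backward_data };

constexpr unsigned kUseGlobalStats = 1u << 0;
constexpr unsigned kUseScale = 1u << 1;
constexpr unsigned kUseShift = 1u << 2;

struct RequiredArgs {
    bool supported = true;        // false for the "Not supported" cells
    std::vector<std::string> in;  // inputs, in the order used by the summary
    std::vector<std::string> out; // outputs, in the order used by the summary
};

// Returns the inputs and outputs listed in the summary for a given
// propagation kind and a bitwise OR of flags.
RequiredArgs required_args(Prop prop, unsigned flags) {
    const bool global_stats = (flags & kUseGlobalStats) != 0;
    const bool scale = (flags & kUseScale) != 0;
    const bool shift = (flags & kUseShift) != 0;

    RequiredArgs a;
    if (prop == Prop::forward_inference || prop == Prop::forward_training) {
        a.in = {"src"};
        if (global_stats) {
            a.in.push_back("mean");
            a.in.push_back("variance");
        }
        if (scale) a.in.push_back("gamma");
        if (shift) a.in.push_back("beta");
        a.out = {"dst"};
        // Statistics are outputs only when computed at runtime during training.
        if (prop == Prop::forward_training && !global_stats) {
            a.out.push_back("mean");
            a.out.push_back("variance");
        }
        return a;
    }

    // backward_data is listed as not supported together with scale or shift.
    if (prop == Prop::backward_data && (scale || shift)) {
        a.supported = false;
        return a;
    }

    // For backward propagation, mean and variance are always inputs.
    a.in = {"diff_dst", "src", "mean", "variance"};
    if (scale) a.in.push_back("gamma");
    if (shift) a.in.push_back("beta");
    a.out = {"diff_src"};
    if (scale) a.out.push_back("diff_gamma");
    if (shift) a.out.push_back("diff_beta");
    return a;
}
```

For instance, required_args(Prop::forward_training, kUseScale | kUseShift) lists src, gamma, and beta as inputs and dst, mean, and variance as outputs, matching the corresponding row of the summary.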

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

| Primitive input/output | Execution argument index |
| --- | --- |
| \(\src\) | DNNL_ARG_SRC |
| \(\gamma\) | DNNL_ARG_SCALE |
| \(\beta\) | DNNL_ARG_SHIFT |
| mean (\(\mu\)) | DNNL_ARG_MEAN |
| variance (\(\sigma^2\)) | DNNL_ARG_VARIANCE |
| \(\dst\) | DNNL_ARG_DST |
| \(\diffdst\) | DNNL_ARG_DIFF_DST |
| \(\diffsrc\) | DNNL_ARG_DIFF_SRC |
| \(\diffgamma\) | DNNL_ARG_DIFF_SCALE |
| \(\diffbeta\) | DNNL_ARG_DIFF_SHIFT |

Operation Details#

  1. The different flavors of the primitive are partially controlled by the flags parameter that is passed to the primitive descriptor initialization function (e.g., dnnl::layer_normalization_forward::primitive_desc). Multiple flags can be combined using the bitwise OR operator (|).

  2. For forward propagation, the mean and variance might be either computed at runtime (in which case they are outputs of the primitive) or provided by a user (in which case they are inputs). In the latter case, a user must set the use_global_stats flag. For the backward propagation, the mean and variance are always input parameters.

  3. Both forward and backward propagation support in-place operations, meaning that \(\src\) can be used as input and output for forward propagation, and \(\diffdst\) can be used as input and output for backward propagation. In case of an in-place operation, the original data will be overwritten. Note, however, that backward propagation requires original \(\src\), hence the corresponding forward propagation should not be performed in-place.

Data Types Support#

The layer normalization primitive supports the following combinations of data types.

Note

Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.

| Propagation | Source / Destination | Mean / Variance / Scale / Shift |
| --- | --- | --- |
| forward / backward | f32 | f32 |
| forward | f16 | f32 |

Data Representation#

Mean and Variance#

The mean (\(\mu\)) and variance (\(\sigma^2\)) are separate tensors with the number of dimensions equal to (\(data\_ndims - 1\)) and size \((data\_dim[0], data\_dim[1], ..., data\_dim[data\_ndims - 2])\).
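As a small illustration of this shape rule, the statistics dimensions are the data dimensions with the last (normalized) axis dropped. The helper below is a hypothetical plain C++ sketch; dims_t and stat_dims_from_data are names chosen for this example and are not part of the oneDNN API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical alias for a list of tensor dimensions.
using dims_t = std::vector<std::int64_t>;

// Returns the dimensions of the mean and variance tensors for a 2-5D
// data tensor: all data dimensions except the last one.
dims_t stat_dims_from_data(const dims_t &data_dims) {
    assert(data_dims.size() >= 2 && data_dims.size() <= 5);
    return dims_t(data_dims.begin(), data_dims.end() - 1);
}
```

For a tnc data tensor with dimensions {10, 32, 512}, this yields statistics dimensions {10, 32}.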

The corresponding memory object can have an arbitrary memory format. Unless mean and variance are computed at runtime and not exposed (i.e., propagation kind is forward_inference and use_global_stats is not set), the user should provide a memory descriptor for statistics when initializing the layer normalization descriptor. For best performance, it is advised to use the memory format that follows the data memory format; i.e., if the data format is tnc, the best performance can be expected for statistics with the tn format and suboptimal for statistics with the nt format.

Scale and Shift#

If used, the scale (\(\gamma\)) and shift (\(\beta\)) are separate 1D tensors of shape \(C\), passed as DNNL_ARG_SCALE and DNNL_ARG_SHIFT respectively.

The format of the corresponding memory objects must be x (a).

Source, Destination, and Their Gradients#

The layer normalization primitive works with an arbitrary data tensor; however, it was designed for RNN data tensors (i.e., nc, tnc, ldnc). Unlike CNN data tensors, RNN data tensors have a single feature dimension. Layer normalization performs normalization over the last logical dimension (feature dimension for RNN tensors) across non-feature dimensions.

The layer normalization primitive is optimized for the following memory formats:

| Logical tensor | Implementations optimized for memory formats |
| --- | --- |
| NC | nc (ab) |
| TNC | tnc (abc), ntc (bac) |
| LDNC | ldnc (abcd) |

API#

struct layer_normalization_forward : public dnnl::primitive#

Layer normalization forward propagation primitive.

Public Functions

layer_normalization_forward()#

Default constructor. Produces an empty object.

layer_normalization_forward(const primitive_desc &pd)#

Constructs a layer normalization forward propagation primitive.

Parameters:

pd – Primitive descriptor for a layer normalization forward propagation primitive.

struct primitive_desc : public dnnl::primitive_desc#

Primitive descriptor for a layer normalization forward propagation primitive.

Public Functions

primitive_desc() = default#

Default constructor. Produces an empty object.

primitive_desc(const engine &aengine, prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags, const primitive_attr &attr = default_attr(), bool allow_empty = false)#

Constructs a primitive descriptor for a layer normalization forward propagation primitive.

Parameters:
  • aengine – Engine to use.

  • aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • src_desc – Source memory descriptor.

  • dst_desc – Destination memory descriptor.

  • stat_desc – Statistics memory descriptors.

  • epsilon – Layer normalization epsilon parameter.

  • flags – Layer normalization flags (dnnl::normalization_flags).

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const engine &aengine, prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &dst_desc, float epsilon, normalization_flags flags, const primitive_attr &attr = default_attr(), bool allow_empty = false)#

Constructs a primitive descriptor for a layer normalization forward propagation primitive.

Parameters:
  • aengine – Engine to use.

  • aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • src_desc – Source memory descriptor.

  • dst_desc – Destination memory descriptor.

  • epsilon – Layer normalization epsilon parameter.

  • flags – Layer normalization flags (dnnl::normalization_flags).

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const#

Returns a source memory descriptor.

Returns:

Source memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a source parameter.

memory::desc dst_desc() const#

Returns a destination memory descriptor.

Returns:

Destination memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a destination parameter.

memory::desc weights_desc() const#

Returns a weights memory descriptor.

Returns:

Weights memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a weights parameter.

memory::desc workspace_desc() const#

Returns the workspace memory descriptor.

Returns:

Workspace memory descriptor.

Returns:

A zero memory descriptor if the primitive does not require a workspace parameter.

memory::desc mean_desc() const#

Returns memory descriptor for mean.

Returns:

Memory descriptor for mean.

memory::desc variance_desc() const#

Returns memory descriptor for variance.

Returns:

Memory descriptor for variance.

dnnl::prop_kind get_prop_kind() const#

Returns a propagation kind.

Returns:

A propagation kind.

Returns:

dnnl::prop_kind::undef if the primitive does not have a propagation parameter.

float get_epsilon() const#

Returns an epsilon.

Returns:

An epsilon.

Returns:

Zero if the primitive does not have an epsilon parameter.

normalization_flags get_flags() const#

Returns normalization flags.

Returns:

Normalization flags.

struct layer_normalization_backward : public dnnl::primitive#

Layer normalization backward propagation primitive.

Public Functions

layer_normalization_backward()#

Default constructor. Produces an empty object.

layer_normalization_backward(const primitive_desc &pd)#

Constructs a layer normalization backward propagation primitive.

Parameters:

pd – Primitive descriptor for a layer normalization backward propagation primitive.

struct primitive_desc : public dnnl::primitive_desc#

Primitive descriptor for a layer normalization backward propagation primitive.

Public Functions

primitive_desc() = default#

Default constructor. Produces an empty object.

primitive_desc(const engine &aengine, prop_kind aprop_kind, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::desc &src_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags, const layer_normalization_forward::primitive_desc &hint_fwd_pd, const primitive_attr &attr = default_attr(), bool allow_empty = false)#

Constructs a primitive descriptor for a layer normalization backward propagation primitive.

Parameters:
  • aengine – Engine to use.

  • aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::backward_data and dnnl::prop_kind::backward (diffs for all parameters are computed in this case).

  • diff_src_desc – Diff source memory descriptor.

  • diff_dst_desc – Diff destination memory descriptor.

  • src_desc – Source memory descriptor.

  • stat_desc – Statistics memory descriptors.

  • epsilon – Layer normalization epsilon parameter.

  • flags – Layer normalization flags (dnnl::normalization_flags).

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • hint_fwd_pd – Primitive descriptor for a layer normalization forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const engine &aengine, prop_kind aprop_kind, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::desc &src_desc, float epsilon, normalization_flags flags, const layer_normalization_forward::primitive_desc &hint_fwd_pd, const primitive_attr &attr = default_attr(), bool allow_empty = false)#

Constructs a primitive descriptor for a layer normalization backward propagation primitive.

Parameters:
  • aengine – Engine to use.

  • aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::backward_data and dnnl::prop_kind::backward (diffs for all parameters are computed in this case).

  • diff_src_desc – Diff source memory descriptor.

  • diff_dst_desc – Diff destination memory descriptor.

  • src_desc – Source memory descriptor.

  • epsilon – Layer normalization epsilon parameter.

  • flags – Layer normalization flags (dnnl::normalization_flags).

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • hint_fwd_pd – Primitive descriptor for a layer normalization forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const#

Returns a source memory descriptor.

Returns:

Source memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a source parameter.

memory::desc weights_desc() const#

Returns a weights memory descriptor.

Returns:

Weights memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a weights parameter.

memory::desc dst_desc() const#

Returns a destination memory descriptor.

Returns:

Destination memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a destination parameter.

memory::desc diff_src_desc() const#

Returns a diff source memory descriptor.

Returns:

Diff source memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a diff source parameter.

memory::desc diff_dst_desc() const#

Returns a diff destination memory descriptor.

Returns:

Diff destination memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a diff destination parameter.

memory::desc diff_weights_desc() const#

Returns a diff weights memory descriptor.

Returns:

Diff weights memory descriptor.

Returns:

A zero memory descriptor if the primitive does not have a diff weights parameter.

memory::desc mean_desc() const#

Returns memory descriptor for mean.

Returns:

Memory descriptor for mean.

memory::desc variance_desc() const#

Returns memory descriptor for variance.

Returns:

Memory descriptor for variance.

memory::desc workspace_desc() const#

Returns the workspace memory descriptor.

Returns:

Workspace memory descriptor.

Returns:

A zero memory descriptor if the primitive does not require a workspace parameter.

dnnl::prop_kind get_prop_kind() const#

Returns a propagation kind.

Returns:

A propagation kind.

Returns:

dnnl::prop_kind::undef if the primitive does not have a propagation parameter.

float get_epsilon() const#

Returns an epsilon.

Returns:

An epsilon.

Returns:

Zero if the primitive does not have an epsilon parameter.

normalization_flags get_flags() const#

Returns normalization flags.

Returns:

Normalization flags.