Local Response Normalization

The LRN primitive performs a forward or backward local response normalization operation defined by the following formulas. Variable names follow the standard Conventions.

Forward

LRN across channels:

\[\dst(n, c, h, w) = \left\{k + \frac{\alpha}{n_{l}} \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2-1} (\src(n, c+i, h, w))^2 \right\}^{-\beta} \cdot \src(n, c, h, w),\]

LRN within channel:

\[\dst(n, c, h, w) = \left\{k + \frac{\alpha}{n_{l}} \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2-1} \sum\limits_{j=-(n_{l}-1)/2}^{(n_{l}+1)/2-1} (\src(n, c, h+i, w+j))^2 \right\}^{-\beta} \cdot \src(n, c, h, w),\]

where \(n_{l}\) is the local_size. The formulas are given for the 2D spatial data case.
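The two forward formulas can be sketched as a plain reference implementation. This is a minimal illustration, not the library's implementation: the `Tensor4` type and all function names are hypothetical, the window bounds use integer division, and source elements that fall outside the tensor are assumed to contribute zero to the sum (the formulas above do not specify the boundary behavior).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical dense f32 tensor, logical dimensions N x C x H x W, nchw order.
struct Tensor4 {
    int N, C, H, W;
    std::vector<float> data;
    Tensor4(int n, int c, int h, int w)
        : N(n), C(c), H(h), W(w),
          data(static_cast<std::size_t>(n) * c * h * w, 0.f) {}
    std::size_t idx(int n, int c, int h, int w) const {
        return ((static_cast<std::size_t>(n) * C + c) * H + h) * W + w;
    }
    float &at(int n, int c, int h, int w) { return data[idx(n, c, h, w)]; }
    float at(int n, int c, int h, int w) const { return data[idx(n, c, h, w)]; }
};

// Window bounds taken from the formulas: [-(n_l-1)/2, (n_l+1)/2-1].
// Integer division is an assumption; for odd n_l the window is symmetric.
inline int window_lo(int local_size) { return -((local_size - 1) / 2); }
inline int window_hi(int local_size) { return (local_size + 1) / 2 - 1; }

// Computes {k + alpha / n_l * sum}^(-beta) * value.
inline float normalize(float value, float sum, int local_size,
                       float alpha, float beta, float k) {
    return std::pow(k + alpha / local_size * sum, -beta) * value;
}

// LRN across channels: squares are summed over neighboring channels.
Tensor4 lrn_across_channels(const Tensor4 &src, int local_size,
                            float alpha, float beta, float k = 1.f) {
    assert(local_size > 0);
    Tensor4 dst(src.N, src.C, src.H, src.W);
    for (int n = 0; n < src.N; ++n)
    for (int c = 0; c < src.C; ++c)
    for (int h = 0; h < src.H; ++h)
    for (int w = 0; w < src.W; ++w) {
        float sum = 0.f;
        for (int i = window_lo(local_size); i <= window_hi(local_size); ++i) {
            const int cc = c + i;
            if (cc < 0 || cc >= src.C) continue; // assumed zero outside
            sum += src.at(n, cc, h, w) * src.at(n, cc, h, w);
        }
        dst.at(n, c, h, w)
                = normalize(src.at(n, c, h, w), sum, local_size, alpha, beta, k);
    }
    return dst;
}

// LRN within channel: squares are summed over a spatial window.
Tensor4 lrn_within_channel(const Tensor4 &src, int local_size,
                           float alpha, float beta, float k = 1.f) {
    assert(local_size > 0);
    Tensor4 dst(src.N, src.C, src.H, src.W);
    for (int n = 0; n < src.N; ++n)
    for (int c = 0; c < src.C; ++c)
    for (int h = 0; h < src.H; ++h)
    for (int w = 0; w < src.W; ++w) {
        float sum = 0.f;
        for (int i = window_lo(local_size); i <= window_hi(local_size); ++i)
        for (int j = window_lo(local_size); j <= window_hi(local_size); ++j) {
            const int hh = h + i, ww = w + j;
            if (hh < 0 || hh >= src.H || ww < 0 || ww >= src.W) continue;
            sum += src.at(n, c, hh, ww) * src.at(n, c, hh, ww);
        }
        dst.at(n, c, h, w)
                = normalize(src.at(n, c, h, w), sum, local_size, alpha, beta, k);
    }
    return dst;
}
```

Both functions follow the formulas literally, including the \(\alpha / n_{l}\) scaling in the within-channel case, so they are useful mainly as a readable cross-check of the math rather than as a performance reference.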

Backward

The backward propagation computes \(\diffsrc(n, c, h, w)\), based on \(\diffdst(n, c, h, w)\) and \(\src(n, c, h, w)\).
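The text above does not spell out the backward formula. As a rough sketch, differentiating the across-channels forward formula gives, for a single \((n, h, w)\) position with \(s_c = k + \frac{\alpha}{n_l} \sum_i \src_{c+i}^2\), the expression \(\diffsrc_c = \diffdst_c \, s_c^{-\beta} - \frac{2 \alpha \beta}{n_l} \src_c \sum_{c'} \diffdst_{c'} \src_{c'} s_{c'}^{-\beta-1}\), where the sum runs over the channels \(c'\) whose window contains \(c\). The code below implements that derivation; the function names are hypothetical, integer division is used for the window bounds, and out-of-range channels are assumed to contribute zero.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: s[c] = k + alpha / n_l * sum_i src[c + i]^2,
// with i in [-(n_l-1)/2, (n_l+1)/2-1] and out-of-range channels skipped.
std::vector<double> lrn_scale(const std::vector<double> &src, int local_size,
                              double alpha, double k) {
    assert(local_size > 0);
    const int C = static_cast<int>(src.size());
    const int lo = -((local_size - 1) / 2), hi = (local_size + 1) / 2 - 1;
    std::vector<double> s(C, k);
    for (int c = 0; c < C; ++c)
        for (int i = lo; i <= hi; ++i) {
            const int cc = c + i;
            if (cc >= 0 && cc < C)
                s[c] += alpha / local_size * src[cc] * src[cc];
        }
    return s;
}

// Forward across channels for one (n, h, w) position: dst = s^(-beta) * src.
std::vector<double> lrn_forward_ref(const std::vector<double> &src,
                                    int local_size, double alpha, double beta,
                                    double k) {
    const std::vector<double> s = lrn_scale(src, local_size, alpha, k);
    std::vector<double> dst(src.size());
    for (std::size_t c = 0; c < src.size(); ++c)
        dst[c] = std::pow(s[c], -beta) * src[c];
    return dst;
}

// Backward across channels, derived from the forward formula (see lead-in).
std::vector<double> lrn_backward_ref(const std::vector<double> &src,
                                     const std::vector<double> &diff_dst,
                                     int local_size, double alpha, double beta,
                                     double k) {
    assert(src.size() == diff_dst.size());
    const int C = static_cast<int>(src.size());
    const int lo = -((local_size - 1) / 2), hi = (local_size + 1) / 2 - 1;
    const std::vector<double> s = lrn_scale(src, local_size, alpha, k);
    std::vector<double> diff_src(C);
    for (int c = 0; c < C; ++c) {
        double acc = 0.0;
        for (int i = lo; i <= hi; ++i) {
            const int cp = c - i; // channel c' whose window reaches c at offset i
            if (cp >= 0 && cp < C)
                acc += diff_dst[cp] * src[cp] * std::pow(s[cp], -beta - 1.0);
        }
        diff_src[c] = diff_dst[c] * std::pow(s[c], -beta)
                - 2.0 * alpha * beta / local_size * src[c] * acc;
    }
    return diff_src;
}
```

A finite-difference comparison against `lrn_forward_ref` is a simple way to sanity-check the derivation. An optimized implementation would likely cache the scale values from the forward pass instead of recomputing them, which is the kind of intermediate result a workspace can hold.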

Execution Arguments

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

| Primitive input/output | Execution argument index |
|------------------------|--------------------------|
| \(\src\)               | DNNL_ARG_SRC             |
| \(\dst\)               | DNNL_ARG_DST             |
| workspace              | DNNL_ARG_WORKSPACE       |
| \(\diffsrc\)           | DNNL_ARG_DIFF_SRC        |
| \(\diffdst\)           | DNNL_ARG_DIFF_DST        |

Operation Details

  1. During training, LRN might or might not require a workspace on the forward and backward passes; the behavior is implementation specific. Optimized implementations typically require a workspace and use it to save intermediate results from the forward pass that accelerate computations on the backward pass. To check whether a workspace is required, query the LRN primitive descriptor for the workspace: a non-zero memory descriptor indicates that a workspace is required, while a zero memory descriptor indicates that it is not.

  2. The memory format and data type for src and dst are assumed to be the same, and in the API are typically referred to as data (e.g., see data_desc in dnnl::lrn_forward::desc::desc()). The same holds for diff_src and diff_dst. The corresponding memory descriptors are referred to as diff_data_desc.

Data Type Support

The LRN primitive supports the following combinations of data types.

Note

Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.

| Propagation        | Source / Destination |
|--------------------|----------------------|
| forward / backward | f32, bf16            |
| forward            | f16                  |

Data Representation

Source, Destination, and Their Gradients

Like most other primitives, the LRN primitive expects the following tensors:

| Spatial | Source / Destination                     |
|---------|------------------------------------------|
| 0D      | \(N \times C\)                           |
| 1D      | \(N \times C \times W\)                  |
| 2D      | \(N \times C \times H \times W\)         |
| 3D      | \(N \times C \times D \times H \times W\) |

The LRN primitive is optimized for the following memory formats:

| Spatial | Logical tensor | Implementations optimized for memory formats |
|---------|----------------|----------------------------------------------|
| 2D      | NCHW           | nchw (abcd), nhwc (acdb), optimized          |

Here optimized means the format chosen by the preceding compute-intensive primitive.
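To make the difference between the two plain formats concrete, the sketch below computes the linear offset of the logical element \((n, c, h, w)\) in each of them. It assumes dense, unpadded storage, and the `Dims` type and function names are hypothetical helpers, not part of the library API.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical container for the logical NCHW dimensions.
struct Dims {
    std::size_t N, C, H, W;
};

// nchw (abcd): W is the innermost (fastest varying) dimension.
inline std::size_t offset_nchw(const Dims &d, std::size_t n, std::size_t c,
                               std::size_t h, std::size_t w) {
    assert(n < d.N && c < d.C && h < d.H && w < d.W);
    return ((n * d.C + c) * d.H + h) * d.W + w;
}

// nhwc (acdb): C is the innermost (fastest varying) dimension.
inline std::size_t offset_nhwc(const Dims &d, std::size_t n, std::size_t c,
                               std::size_t h, std::size_t w) {
    assert(n < d.N && c < d.C && h < d.H && w < d.W);
    return ((n * d.H + h) * d.W + w) * d.C + c;
}
```

With these layouts, neighboring channels at a fixed spatial position are `H * W` elements apart in nchw and adjacent in nhwc, which affects the access pattern of the across-channels sum.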

Post-ops and Attributes

The LRN primitive does not support any post-ops or attributes.

API

struct dnnl::lrn_forward : public dnnl::primitive

Local response normalization (LRN) forward propagation primitive.

Public Functions

lrn_forward()

Default constructor. Produces an empty object.

lrn_forward(const primitive_desc &pd)

Constructs an LRN forward propagation primitive.

Parameters

pd – Primitive descriptor for an LRN forward propagation primitive.

struct desc

Descriptor for an LRN forward propagation primitive.

Public Functions

desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, memory::dim local_size, float alpha, float beta, float k = 1.f)

Constructs a descriptor for an LRN forward propagation primitive.

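Parameters
  • aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • aalgorithm – LRN algorithm kind: either dnnl::algorithm::lrn_across_channels, or dnnl::algorithm::lrn_within_channel.

  • data_desc – Source and destination memory descriptor.

  • local_size – Regularization local size.

  • alpha – The alpha regularization parameter.

  • beta – The beta regularization parameter.

  • k – The k regularization parameter.

struct primitive_desc : public dnnl::primitive_desc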

Primitive descriptor for an LRN forward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for an LRN forward propagation primitive.

Parameters
  • adesc – Descriptor for an LRN forward propagation primitive.

  • aengine – Engine to use.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for an LRN forward propagation primitive.

Parameters
  • adesc – Descriptor for an LRN forward propagation primitive.

  • attr – Primitive attributes to use.

  • aengine – Engine to use.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Returns

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Returns

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.

memory::desc workspace_desc() const

Returns the workspace memory descriptor.

Returns

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.

struct dnnl::lrn_backward : public dnnl::primitive

Local response normalization (LRN) backward propagation primitive.

Public Functions

lrn_backward()

Default constructor. Produces an empty object.

lrn_backward(const primitive_desc &pd)

Constructs an LRN backward propagation primitive.

Parameters

pd – Primitive descriptor for an LRN backward propagation primitive.

struct desc

Descriptor for an LRN backward propagation primitive.

Public Functions

desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k = 1.f)

Constructs a descriptor for an LRN backward propagation primitive.

Parameters
  • aalgorithm – LRN algorithm kind: either dnnl::algorithm::lrn_across_channels, or dnnl::algorithm::lrn_within_channel.

  • data_desc – Source memory descriptor.

  • diff_data_desc – Diff source and diff destination memory descriptor.

  • local_size – Regularization local size.

  • alpha – The alpha regularization parameter.

  • beta – The beta regularization parameter.

  • k – The k regularization parameter.

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for an LRN backward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for an LRN backward propagation primitive.

Parameters
  • adesc – Descriptor for an LRN backward propagation primitive.

  • aengine – Engine to use.

  • hint_fwd_pd – Primitive descriptor for an LRN forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for an LRN backward propagation primitive.

Parameters
  • adesc – Descriptor for an LRN backward propagation primitive.

  • attr – Primitive attributes to use.

  • aengine – Engine to use.

  • hint_fwd_pd – Primitive descriptor for an LRN forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc diff_src_desc() const

Returns a diff source memory descriptor.

Returns

Diff source memory descriptor, or a zero memory descriptor if the primitive does not have a diff source parameter.

memory::desc diff_dst_desc() const

Returns a diff destination memory descriptor.

Returns

Diff destination memory descriptor, or a zero memory descriptor if the primitive does not have a diff destination parameter.

memory::desc workspace_desc() const

Returns the workspace memory descriptor.

Returns

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.