Local Response Normalization
The LRN primitive performs a forward or backward local response normalization operation defined by the following formulas. Variable names follow the standard Conventions.
Forward
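LRN across channels:
\[\dst(n, c, h, w) = \left\{k + \frac{\alpha}{n_{l}} \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2-1} (\src(n, c+i, h, w))^2 \right\}^{-\beta} \cdot \src(n, c, h, w),\]
LRN within channel:
\[\dst(n, c, h, w) = \left\{k + \frac{\alpha}{n_{l}^2} \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2-1} \sum\limits_{j=-(n_{l}-1)/2}^{(n_{l}+1)/2-1} (\src(n, c, h+i, w+j))^2 \right\}^{-\beta} \cdot \src(n, c, h, w),\]
where \(n_{l}\) is the local_size. Formulas are provided for the 2D spatial data case.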
Backward
The backward propagation computes \(\diffsrc(n, c, h, w)\) based on \(\diffdst(n, c, h, w)\) and \(\src(n, c, h, w)\).
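For across-channel LRN, differentiating the forward formula gives a closed form for the gradient. Fix \(n\), \(h\), \(w\), write \(S_c = k + \frac{\alpha}{n_{l}} \sum_{i \in W(c)} \src_i^2\) for the base at channel \(c\), where \(W(c)\) is its channel window. Then \(\diffsrc_j = \diffdst_j S_j^{-\beta} - \frac{2\alpha\beta}{n_{l}} \src_j \sum_{c:\, j \in W(c)} \diffdst_c \src_c S_c^{-\beta-1}\). The sketch below is a hypothetical reference (the function names are not oneDNN API) that works on one spatial position in double precision and covers only the across-channel algorithm. It uses the same zero-padding assumption as a clipped window.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: S_c = k + alpha / n_l * sum_{i in window(c)} x_i^2
// for across-channel LRN at a single spatial position. Channels outside
// [0, C) are treated as zero.
std::vector<double> lrn_bases(const std::vector<double> &x, long local_size,
        double alpha, double k) {
    const long C = static_cast<long>(x.size());
    const long lo = (local_size - 1) / 2, hi = (local_size + 1) / 2 - 1;
    std::vector<double> S(x.size());
    for (long c = 0; c < C; ++c) {
        double sum = 0.0;
        for (long i = std::max(c - lo, 0L); i <= std::min(c + hi, C - 1); ++i)
            sum += x[i] * x[i];
        S[c] = k + alpha / local_size * sum;
    }
    return S;
}

// Hypothetical reference backward pass for across-channel LRN:
// diff_src_j = diff_dst_j * S_j^-beta
//   - 2*alpha*beta/n_l * x_j * sum_{c : j in window(c)} diff_dst_c * x_c * S_c^(-beta-1)
std::vector<double> lrn_backward_across(const std::vector<double> &x,
        const std::vector<double> &diff_dst, long local_size,
        double alpha, double beta, double k = 1.0) {
    assert(x.size() == diff_dst.size() && local_size > 0);
    const long C = static_cast<long>(x.size());
    const long lo = (local_size - 1) / 2, hi = (local_size + 1) / 2 - 1;
    const std::vector<double> S = lrn_bases(x, local_size, alpha, k);
    std::vector<double> diff_src(x.size());
    for (long j = 0; j < C; ++j) {
        // j lies in window [c - lo, c + hi] exactly when c is in [j - hi, j + lo].
        double acc = 0.0;
        for (long c = std::max(j - hi, 0L); c <= std::min(j + lo, C - 1); ++c)
            acc += diff_dst[c] * x[c] * std::pow(S[c], -beta - 1.0);
        diff_src[j] = diff_dst[j] * std::pow(S[j], -beta)
                - 2.0 * alpha * beta / local_size * x[j] * acc;
    }
    return diff_src;
}
```

A central finite-difference check of \(\sum_c \diffdst_c \dst_c\) with respect to each \(\src_j\) is a convenient way to validate a hand-derived gradient like this one.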
Execution Arguments
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
Primitive input/output | Execution argument index |
---|---|
\(\src\) | DNNL_ARG_SRC |
\(\dst\) | DNNL_ARG_DST |
workspace | DNNL_ARG_WORKSPACE |
\(\diffsrc\) | DNNL_ARG_DIFF_SRC |
\(\diffdst\) | DNNL_ARG_DIFF_DST |
Operation Details
During training, LRN may or may not require a workspace on the forward and backward passes; this behavior is implementation-specific. Optimized implementations typically require a workspace and use it to save intermediate results from the forward pass that accelerate computations on the backward pass. To check whether a workspace is required, query the LRN primitive descriptor for the workspace. Success indicates that the workspace is required, and its description is returned.
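To make the workspace idea concrete, here is a hypothetical sketch in which forward training saves the per-channel base \(S_c = k + \frac{\alpha}{n_{l}} \sum_{i \in W(c)} \src_i^2\) and the backward pass reads it back instead of recomputing the windowed sum of squares. The `LrnWorkspace` struct and both function names are assumptions for illustration only; the actual contents and layout of a oneDNN workspace are implementation-specific and should be treated as opaque.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical workspace layout for this sketch: one base S_c per channel.
// Real oneDNN workspaces are opaque and implementation specific.
struct LrnWorkspace { std::vector<float> base; };

// Forward training, across channels, at a single spatial position.
// Fills ws with S_c so that the backward pass does not have to recompute it.
std::vector<float> lrn_fwd_training(const std::vector<float> &src, long local_size,
        float alpha, float beta, float k, LrnWorkspace &ws) {
    assert(local_size > 0);
    const long C = static_cast<long>(src.size());
    const long lo = (local_size - 1) / 2, hi = (local_size + 1) / 2 - 1;
    ws.base.assign(src.size(), 0.f);
    std::vector<float> dst(src.size());
    for (long c = 0; c < C; ++c) {
        float sum = 0.f;
        for (long i = std::max(c - lo, 0L); i <= std::min(c + hi, C - 1); ++i)
            sum += src[i] * src[i];
        ws.base[c] = k + alpha / local_size * sum; // saved for backward
        dst[c] = src[c] * std::pow(ws.base[c], -beta);
    }
    return dst;
}

// Backward pass that reads S_c from the workspace instead of
// re-running the windowed sum of squares.
std::vector<float> lrn_bwd_with_ws(const std::vector<float> &src,
        const std::vector<float> &diff_dst, const LrnWorkspace &ws,
        long local_size, float alpha, float beta) {
    assert(ws.base.size() == src.size() && diff_dst.size() == src.size());
    const long C = static_cast<long>(src.size());
    const long lo = (local_size - 1) / 2, hi = (local_size + 1) / 2 - 1;
    std::vector<float> diff_src(src.size());
    for (long j = 0; j < C; ++j) {
        float acc = 0.f; // sum over channels c whose window contains j
        for (long c = std::max(j - hi, 0L); c <= std::min(j + lo, C - 1); ++c)
            acc += diff_dst[c] * src[c] * std::pow(ws.base[c], -beta - 1.f);
        diff_src[j] = diff_dst[j] * std::pow(ws.base[j], -beta)
                - 2.f * alpha * beta / local_size * src[j] * acc;
    }
    return diff_src;
}
```

The trade-off is memory for compute: the workspace costs one extra value per element but spares the backward pass a second windowed reduction.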
The memory format and data type for src and dst are assumed to be the same, and in the API they are typically referred to as data (e.g., see data_desc in dnnl::lrn_forward::desc::desc()). The same holds for diff_src and diff_dst. The corresponding memory descriptors are referred to as diff_data_desc.
Data Type Support
The LRN primitive supports the following combinations of data types.
Note
Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.
Propagation | Source / Destination |
---|---|
forward / backward | f32, bf16 |
forward | f16 |
Data Representation
Source, Destination, and Their Gradients
Like most other primitives, the LRN primitive expects the following tensors:
Spatial | Source / Destination |
---|---|
0D | \(N \times C\) |
1D | \(N \times C \times W\) |
2D | \(N \times C \times H \times W\) |
3D | \(N \times C \times D \times H \times W\) |
The LRN primitive is optimized for the following memory formats:
Spatial | Logical tensor | Implementations optimized for memory formats |
---|---|---|
2D | NCHW | dnnl_nchw (dnnl_abcd), dnnl_nhwc (dnnl_acdb), optimized |
Here optimized means the format chosen by the preceding compute-intensive primitive.
Post-ops and Attributes
The LRN primitive does not support any post-ops or attributes.
API
-
struct dnnl::lrn_forward : public dnnl::primitive
Local response normalization (LRN) forward propagation primitive.
Public Functions
-
lrn_forward()
Default constructor. Produces an empty object.
-
lrn_forward(const primitive_desc &pd)
Constructs an LRN forward propagation primitive.
- Parameters
pd – Primitive descriptor for an LRN forward propagation primitive.
-
struct desc
Descriptor for an LRN forward propagation primitive.
Public Functions
-
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, memory::dim local_size, float alpha, float beta, float k = 1.f)
Constructs a descriptor for an LRN forward propagation primitive.
- Parameters
aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.
aalgorithm – LRN algorithm kind: either dnnl::algorithm::lrn_across_channels or dnnl::algorithm::lrn_within_channel.
data_desc – Source and destination memory descriptor.
local_size – Regularization local size.
alpha – The alpha regularization parameter.
beta – The beta regularization parameter.
k – The k regularization parameter.
-
struct primitive_desc : public dnnl::primitive_desc
Primitive descriptor for an LRN forward propagation primitive.
Public Functions
-
primitive_desc()
Default constructor. Produces an empty object.
-
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
- Parameters
adesc – Descriptor for an LRN forward propagation primitive.
aengine – Engine to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty = false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
- Parameters
adesc – Descriptor for an LRN forward propagation primitive.
attr – Primitive attributes to use.
aengine – Engine to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
memory::desc src_desc() const
Returns a source memory descriptor.
- Returns
Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.
-
struct dnnl::lrn_backward : public dnnl::primitive
Local response normalization (LRN) backward propagation primitive.
Public Functions
-
lrn_backward()
Default constructor. Produces an empty object.
-
lrn_backward(const primitive_desc &pd)
Constructs an LRN backward propagation primitive.
- Parameters
pd – Primitive descriptor for an LRN backward propagation primitive.
-
struct desc
Descriptor for an LRN backward propagation primitive.
Public Functions
-
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k = 1.f)
Constructs a descriptor for an LRN backward propagation primitive.
- Parameters
aalgorithm – LRN algorithm kind: either dnnl::algorithm::lrn_across_channels or dnnl::algorithm::lrn_within_channel.
data_desc – Source memory descriptor.
diff_data_desc – Diff source and diff destination memory descriptor.
local_size – Regularization local size.
alpha – The alpha regularization parameter.
beta – The beta regularization parameter.
k – The k regularization parameter.
-
struct primitive_desc : public dnnl::primitive_desc
Primitive descriptor for an LRN backward propagation primitive.
Public Functions
-
primitive_desc()
Default constructor. Produces an empty object.
-
primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
- Parameters
adesc – Descriptor for an LRN backward propagation primitive.
aengine – Engine to use.
hint_fwd_pd – Primitive descriptor for an LRN forward propagation primitive. It is used as a hint for deciding which memory format to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
- Parameters
adesc – Descriptor for an LRN backward propagation primitive.
attr – Primitive attributes to use.
aengine – Engine to use.
hint_fwd_pd – Primitive descriptor for an LRN forward propagation primitive. It is used as a hint for deciding which memory format to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
- Returns
Diff source memory descriptor, or a zero memory descriptor if the primitive does not have a diff source parameter.