Inner Product¶
The inner product primitive (sometimes called a fully connected layer) treats each activation in the minibatch as a vector and computes its product with a 2D weights tensor, producing a 2D tensor as output.
Forward¶
Let \(\src\), \(\weights\), \(\bias\), and \(\dst\) be \(N \times IC\), \(OC \times IC\), \(OC\), and \(N \times OC\) tensors, respectively. Variable names follow the standard Conventions. Then:
\[\dst(n, oc) = \bias(oc) + \sum_{ic=0}^{IC-1} \src(n, ic) \cdot \weights(oc, ic).\]
In cases where the \(\src\) and \(\weights\) tensors have spatial dimensions, they are flattened to 2D. For example, if they are 4D \(N \times IC' \times IH \times IW\) and \(OC \times IC' \times KH \times KW\) tensors, then the formula above is applied with \(IC = IC' \cdot IH \cdot IW\). In such cases, the \(\src\) and \(\weights\) tensors must have equal spatial dimensions (e.g. \(KH = IH\) and \(KW = IW\) for 4D tensors).
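The forward formula can be checked against a naive reference loop. The sketch below is a minimal illustration, not oneDNN's implementation: it assumes dense, row-major f32 buffers, and the function name `inner_product_fwd_ref` is hypothetical. A contiguous \(N \times IC' \times IH \times IW\) source can be passed unchanged with \(IC = IC' \cdot IH \cdot IW\), because flattening a dense row-major tensor does not move any data (the weights must be flattened the same way).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical naive reference for the inner product forward pass.
// Assumes dense row-major f32 buffers:
//   src: N x IC, weights: OC x IC, bias: OC (pass an empty vector for no bias),
//   result (dst): N x OC.
std::vector<float> inner_product_fwd_ref(const std::vector<float> &src,
        const std::vector<float> &weights, const std::vector<float> &bias,
        std::size_t N, std::size_t IC, std::size_t OC) {
    std::vector<float> dst(N * OC);
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t oc = 0; oc < OC; ++oc) {
            // dst(n, oc) = bias(oc) + sum_ic src(n, ic) * weights(oc, ic)
            float acc = bias.empty() ? 0.0f : bias[oc];
            for (std::size_t ic = 0; ic < IC; ++ic)
                acc += src[n * IC + ic] * weights[oc * IC + ic];
            dst[n * OC + oc] = acc;
        }
    }
    return dst;
}
```

For example, with \(N = 2\), \(IC = 2\), \(OC = 3\), `src = {1, 2, 3, 4}`, `weights = {1, 0, 0, 1, 1, 1}`, and `bias = {0.5, 0, -1}`, the result is `{1.5, 2, 2, 3.5, 4, 6}`.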
Difference Between Forward Training and Forward Inference¶
There is no difference between the forward_training and forward_inference propagation kinds.
Backward¶
The backward propagation computes \(\diffsrc\) based on \(\diffdst\) and \(\weights\).
The weights update computes \(\diffweights\) and \(\diffbias\) based on \(\diffdst\) and \(\src\).
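Differentiating the forward formula gives \(\diffsrc(n, ic) = \sum_{oc} \diffdst(n, oc) \cdot \weights(oc, ic)\), \(\diffweights(oc, ic) = \sum_{n} \diffdst(n, oc) \cdot \src(n, ic)\), and \(\diffbias(oc) = \sum_{n} \diffdst(n, oc)\). The following is a minimal sketch of both backward computations, again assuming dense row-major f32 buffers; the function names are hypothetical and this is not oneDNN's implementation.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference for backward data:
//   diff_src(n, ic) = sum_oc diff_dst(n, oc) * weights(oc, ic)
// diff_dst: N x OC, weights: OC x IC, result: N x IC.
std::vector<float> inner_product_bwd_data_ref(const std::vector<float> &diff_dst,
        const std::vector<float> &weights, std::size_t N, std::size_t IC,
        std::size_t OC) {
    std::vector<float> diff_src(N * IC, 0.0f);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t ic = 0; ic < IC; ++ic) {
            float acc = 0.0f;
            for (std::size_t oc = 0; oc < OC; ++oc)
                acc += diff_dst[n * OC + oc] * weights[oc * IC + ic];
            diff_src[n * IC + ic] = acc;
        }
    return diff_src;
}

// Hypothetical reference for the weights update:
//   diff_weights(oc, ic) = sum_n diff_dst(n, oc) * src(n, ic)
//   diff_bias(oc)        = sum_n diff_dst(n, oc)
void inner_product_bwd_weights_ref(const std::vector<float> &src,
        const std::vector<float> &diff_dst, std::size_t N, std::size_t IC,
        std::size_t OC, std::vector<float> &diff_weights,
        std::vector<float> &diff_bias) {
    diff_weights.assign(OC * IC, 0.0f);
    diff_bias.assign(OC, 0.0f);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oc = 0; oc < OC; ++oc) {
            const float g = diff_dst[n * OC + oc];
            diff_bias[oc] += g;
            for (std::size_t ic = 0; ic < IC; ++ic)
                diff_weights[oc * IC + ic] += g * src[n * IC + ic];
        }
}
```

Note that backward data needs \(\weights\) but not \(\src\), while the weights update needs \(\src\) but not \(\weights\), which matches the inputs listed above.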
Note
The optimized memory formats for \(\src\) and \(\weights\) may differ between forward propagation, backward propagation, and weights update.
Execution Arguments¶
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
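Primitive input/output | Execution argument index |
---|---|
\(\src\) | DNNL_ARG_SRC |
\(\weights\) | DNNL_ARG_WEIGHTS |
\(\bias\) | DNNL_ARG_BIAS |
\(\dst\) | DNNL_ARG_DST |
\(\diffsrc\) | DNNL_ARG_DIFF_SRC |
\(\diffweights\) | DNNL_ARG_DIFF_WEIGHTS |
\(\diffbias\) | DNNL_ARG_DIFF_BIAS |
\(\diffdst\) | DNNL_ARG_DIFF_DST |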
Operation Details¶
N/A
Data Types Support¶
The inner product primitive supports the following combinations of data types for source, destination, weights, and bias.
Note
Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.
Propagation |
Source |
Weights |
Destination |
Bias |
---|---|---|---|---|
forward / backward |
||||
forward |
||||
forward |
||||
forward |
||||
backward |
||||
weights update |
Data Representation¶
Like other CNN primitives, the inner product primitive expects the following tensors:
Spatial | Source | Destination | Weights |
---|---|---|---|
1D | \(N \times C \times W\) | \(N \times C\) | \(OC \times IC \times KW\) |
2D | \(N \times C \times H \times W\) | \(N \times C\) | \(OC \times IC \times KH \times KW\) |
3D | \(N \times C \times D \times H \times W\) | \(N \times C\) | \(OC \times IC \times KD \times KH \times KW\) |
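The shape rules in this table, together with the flattening rule from the Forward section, can be expressed as a small validation helper. This is a hedged sketch: `flattened_ic` and `dst_dims` are hypothetical names, not oneDNN functions, and the library performs its own checks when a descriptor is created.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

using dims = std::vector<std::int64_t>;

// Hypothetical helper (not part of oneDNN): validates a source shape
// N x C [x D] [x H] [x W] against a weights shape OC x IC [x KD] [x KH] [x KW]
// and returns the flattened input channel count IC = C * D * H * W.
std::int64_t flattened_ic(const dims &src, const dims &weights) {
    // Ranks 2 (no spatial dims) through 5 (3D spatial), and both must agree.
    if (src.size() < 2 || src.size() > 5 || src.size() != weights.size())
        throw std::invalid_argument("src and weights must have the same rank (2 to 5)");
    if (src[1] != weights[1])
        throw std::invalid_argument("src channels must equal weights IC");
    // Spatial dims of src and weights must be equal (e.g. KH == H, KW == W).
    for (std::size_t i = 2; i < src.size(); ++i)
        if (src[i] != weights[i])
            throw std::invalid_argument("spatial dimensions must match");
    // Product of every dimension except the minibatch N.
    return std::accumulate(src.begin() + 1, src.end(), std::int64_t{1},
            std::multiplies<std::int64_t>());
}

// The destination is always N x OC, whatever the spatial rank.
dims dst_dims(const dims &src, const dims &weights) {
    flattened_ic(src, weights); // throws if the shapes are inconsistent
    return {src[0], weights[0]};
}
```

For a 2D case with source \(2 \times 3 \times 4 \times 5\) and weights \(10 \times 3 \times 4 \times 5\), the flattened \(IC\) is 60 and the destination is \(2 \times 10\).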
The memory format of the data and weights memory objects is critical for inner product primitive performance. In the oneDNN programming model, the inner product primitive is one of the few primitives that support the placeholder memory format any (dnnl::memory::format_tag::any) and can define the data and weights memory formats based on the primitive parameters. When using any, it is necessary to first create an inner product primitive descriptor and then query it for the actual data and weights memory formats.
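To see why the memory format matters, note that the same logical element lands at different physical offsets depending on the layout. The sketch below computes offsets for two plain 4D layouts, nchw (format tag abcd) and nhwc (acdb), assuming a dense tensor with no padding; the helper names are hypothetical and are not part of the oneDNN API.

```cpp
#include <cstddef>

// Hypothetical helpers: physical offset of logical element (n, c, h, w)
// in a dense N x C x H x W tensor with no padding.

// nchw (abcd): w varies fastest, then h, then c, then n.
std::size_t offset_nchw(std::size_t n, std::size_t c, std::size_t h,
        std::size_t w, std::size_t C, std::size_t H, std::size_t W) {
    return ((n * C + c) * H + h) * W + w;
}

// nhwc (acdb): c varies fastest, then w, then h, then n.
std::size_t offset_nhwc(std::size_t n, std::size_t c, std::size_t h,
        std::size_t w, std::size_t C, std::size_t H, std::size_t W) {
    return ((n * H + h) * W + w) * C + c;
}
```

For a \(1 \times 2 \times 2 \times 2\) tensor, element \((0, 1, 0, 0)\) sits at offset 4 in nchw but at offset 1 in nhwc. Because the inner product sums over every \((c, h, w)\) position, either layout gives the same result as long as source and weights are traversed in the same order; the placeholder format any leaves the choice of layout to the library.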
The table below shows the plain memory format combinations for which the inner product primitive is optimized. For the destination tensor (which is always \(N \times C\)), the memory format is always nc (ab).
Spatial |
Source / Weights logical tensor |
Implementation optimized for memory formats |
---|---|---|
0D |
NC / OI |
|
0D |
NC / OI |
|
1D |
NCW / OIW |
|
1D |
NCW / OIW |
|
2D |
NCHW / OIHW |
|
2D |
NCHW / OIHW |
|
3D |
NCDHW / OIDHW |
|
3D |
NCDHW / OIDHW |
Post-ops and Attributes¶
The following post-ops should be supported by inner product primitives:
Propagation | Type | Operation | Description | Restrictions |
---|---|---|---|---|
forward | attribute | Output scale | Scales the result of inner product by given scale factor(s) | int8 inner products only |
forward | post-op | Eltwise | Applies an elementwise operation to the result | |
forward | post-op | Sum | Adds the operation result to the destination tensor instead of overwriting it | |
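The effect of these attributes and post-ops on each destination element can be sketched as follows. This is an illustration only: `post_ops_sketch` and `apply_post_ops` are hypothetical and are not the oneDNN `dnnl::post_ops` API. The sketch fixes one order (output scale, then sum, then a ReLU eltwise); in the library, post-ops are applied in the order in which they are appended to the attribute.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical post-op description for illustration only (not the oneDNN API).
struct post_ops_sketch {
    float output_scale = 1.0f; // attribute: scales the raw inner product result
    bool has_sum = false;      // post-op: accumulate into the existing destination
    float sum_scale = 1.0f;    // scale applied to the previous destination contents
    bool has_relu = false;     // post-op: elementwise ReLU, applied last here
};

// Combines the raw inner product result with dst in this fixed order:
// output scale, then sum, then ReLU.
void apply_post_ops(const std::vector<float> &ip_result,
        std::vector<float> &dst, const post_ops_sketch &po) {
    dst.resize(ip_result.size()); // new elements (if any) start at zero
    for (std::size_t i = 0; i < ip_result.size(); ++i) {
        float v = po.output_scale * ip_result[i];
        if (po.has_sum) v += po.sum_scale * dst[i]; // keep the old dst value
        if (po.has_relu) v = std::max(v, 0.0f);
        dst[i] = v;
    }
}
```

With a raw result of `{1, -2, 3}`, an existing destination of `{10, 10, -10}`, an output scale of 0.5, a sum, and a ReLU, the destination becomes `{10.5, 9, 0}`.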
API¶
-
struct dnnl::inner_product_forward : public dnnl::primitive¶
Inner product forward propagation primitive.
Public Functions
-
inner_product_forward()¶
Default constructor. Produces an empty object.
-
inner_product_forward(const primitive_desc &pd)¶
Constructs an inner product forward propagation primitive.
- Parameters
pd – Primitive descriptor for an inner product forward propagation primitive.
-
struct desc¶
Descriptor for an inner product forward propagation primitive.
Public Functions
-
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)¶
Constructs a descriptor for an inner product forward propagation primitive with bias.
Note
All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.
- Parameters
aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.
src_desc – Memory descriptor for src.
weights_desc – Memory descriptor for weights.
bias_desc – Memory descriptor for bias.
dst_desc – Memory descriptor for dst.
-
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)¶
Constructs a descriptor for an inner product forward propagation primitive without bias.
Note
All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.
- Parameters
aprop_kind – Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.
src_desc – Memory descriptor for src.
weights_desc – Memory descriptor for weights.
dst_desc – Memory descriptor for dst.
-
struct primitive_desc : public dnnl::primitive_desc¶
Primitive descriptor for an inner product forward propagation primitive.
Public Functions
-
primitive_desc()¶
Default constructor. Produces an empty object.
-
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)¶
Constructs a primitive descriptor for an inner product forward propagation primitive.
- Parameters
adesc – Descriptor for an inner product forward propagation primitive.
aengine – Engine to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty = false)¶
Constructs a primitive descriptor for an inner product forward propagation primitive.
- Parameters
adesc – Descriptor for an inner product forward propagation primitive.
attr – Primitive attributes to use.
aengine – Engine to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
memory::desc src_desc() const¶
Returns a source memory descriptor.
- Returns
Source memory descriptor.
- Returns
A zero memory descriptor if the primitive does not have a source parameter.
-
memory::desc weights_desc() const¶
Returns a weights memory descriptor.
- Returns
Weights memory descriptor.
- Returns
A zero memory descriptor if the primitive does not have a weights parameter.
-
struct dnnl::inner_product_backward_data : public dnnl::primitive¶
Inner product backward propagation primitive.
Public Functions
-
inner_product_backward_data()¶
Default constructor. Produces an empty object.
-
inner_product_backward_data(const primitive_desc &pd)¶
Constructs an inner product backward propagation primitive.
- Parameters
pd – Primitive descriptor for an inner product backward propagation primitive.
-
struct desc¶
Descriptor for an inner product backward propagation primitive.
Public Functions
-
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)¶
Constructs a descriptor for an inner product backward propagation primitive.
Note
All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.
- Parameters
diff_src_desc – Memory descriptor for diff src.
weights_desc – Memory descriptor for weights.
diff_dst_desc – Memory descriptor for diff dst.
-
struct primitive_desc : public dnnl::primitive_desc¶
Primitive descriptor for an inner product backward propagation primitive.
Public Functions
-
primitive_desc()¶
Default constructor. Produces an empty object.
-
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)¶
Constructs a primitive descriptor for an inner product backward propagation primitive.
- Parameters
adesc – Descriptor for an inner product backward propagation primitive.
aengine – Engine to use.
hint_fwd_pd – Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)¶
Constructs a primitive descriptor for an inner product backward propagation primitive.
- Parameters
adesc – Descriptor for an inner product backward propagation primitive.
attr – Primitive attributes to use.
aengine – Engine to use.
hint_fwd_pd – Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
memory::desc diff_src_desc() const¶
Returns a diff source memory descriptor.
- Returns
Diff source memory descriptor.
- Returns
A zero memory descriptor if the primitive does not have a diff source parameter.
-
struct dnnl::inner_product_backward_weights : public dnnl::primitive¶
Inner product weights gradient primitive.
Public Functions
-
inner_product_backward_weights()¶
Default constructor. Produces an empty object.
-
inner_product_backward_weights(const primitive_desc &pd)¶
Constructs an inner product weights gradient primitive.
- Parameters
pd – Primitive descriptor for an inner product weights gradient primitive.
-
struct desc¶
Descriptor for an inner product weights gradient primitive.
Public Functions
-
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)¶
Constructs a descriptor for an inner product weights update primitive with bias.
Note
All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.
- Parameters
src_desc – Memory descriptor for src.
diff_weights_desc – Memory descriptor for diff weights.
diff_bias_desc – Memory descriptor for diff bias.
diff_dst_desc – Memory descriptor for diff dst.
-
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)¶
Constructs a descriptor for an inner product weights update primitive without bias.
Note
All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.
- Parameters
src_desc – Memory descriptor for src.
diff_weights_desc – Memory descriptor for diff weights.
diff_dst_desc – Memory descriptor for diff dst.
-
struct primitive_desc : public dnnl::primitive_desc¶
Primitive descriptor for an inner product weights gradient primitive.
Public Functions
-
primitive_desc()¶
Default constructor. Produces an empty object.
-
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)¶
Constructs a primitive descriptor for an inner product weights update primitive.
- Parameters
adesc – Descriptor for an inner product weights gradient primitive.
aengine – Engine to use.
hint_fwd_pd – Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)¶
Constructs a primitive descriptor for an inner product weights update primitive.
- Parameters
adesc – Descriptor for an inner product weights gradient primitive.
attr – Primitive attributes to use.
aengine – Engine to use.
hint_fwd_pd – Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.
-
memory::desc src_desc() const¶
Returns a source memory descriptor.
- Returns
Source memory descriptor.
- Returns
A zero memory descriptor if the primitive does not have a source parameter.
-
memory::desc diff_weights_desc() const¶
Returns a diff weights memory descriptor.
- Returns
Diff weights memory descriptor.
- Returns
A zero memory descriptor if the primitive does not have a diff weights parameter.