Matrix Multiplication
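The matrix multiplication (MatMul) primitive computes the product of two 2D tensors with optional bias addition. Variable names follow the standard Conventions.
\[
\dst(m, n) = \sum_{k=0}^{K-1} \left( \src(m, k) \cdot \weights(k, n) \right) + \bias(m, n)
\]
The MatMul primitive also supports batching multiple independent matrix multiplication operations, in which case the tensors must be 3D:
\[
\dst(mb, m, n) = \sum_{k=0}^{K-1} \left( \src(mb, m, k) \cdot \weights(mb, k, n) \right) + \bias(mb, m, n)
\]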
The bias tensor is optional and supports implicit broadcast semantics: any of its dimensions can be 1, in which case the same value is used across the corresponding dimension of \(\dst\). However, \(\bias\) must have the same number of dimensions as \(\dst\).
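The computation described above can be checked against a naive reference. The C++ sketch below is not oneDNN code: `reference_matmul` is a hypothetical helper, all buffers are assumed dense and row-major, and the 2D case is handled as 3D with MB = 1 (so a 2D bias of shape (M or 1) × (N or 1) is passed with bias_dims = {1, rows, cols}).

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference for the batched MatMul computation (hypothetical helper):
//   dst(mb, m, n) = sum over k of src(mb, m, k) * weights(mb, k, n)
//                   + bias(mb, m, n)
// All buffers are dense and row-major (plain abc layout); use MB = 1 for 2D.
// Each entry of bias_dims must equal the matching dst dimension or be 1;
// a size-1 dimension is broadcast. An empty bias vector means "no bias".
std::vector<float> reference_matmul(const std::vector<float> &src,
        const std::vector<float> &weights, const std::vector<float> &bias,
        const std::array<std::size_t, 3> &bias_dims, std::size_t MB,
        std::size_t M, std::size_t K, std::size_t N) {
    assert(src.size() == MB * M * K);
    assert(weights.size() == MB * K * N);
    if (!bias.empty()) {
        const std::array<std::size_t, 3> dst_dims = {MB, M, N};
        for (std::size_t d = 0; d < 3; ++d)
            assert(bias_dims[d] == dst_dims[d] || bias_dims[d] == 1);
        assert(bias.size() == bias_dims[0] * bias_dims[1] * bias_dims[2]);
    }
    std::vector<float> dst(MB * M * N, 0.0f);
    for (std::size_t mb = 0; mb < MB; ++mb)
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < K; ++k)
                    acc += src[(mb * M + m) * K + k]
                            * weights[(mb * K + k) * N + n];
                if (!bias.empty()) {
                    // Broadcast: a size-1 bias dimension always uses index 0.
                    const std::size_t bmb = bias_dims[0] == 1 ? 0 : mb;
                    const std::size_t bm = bias_dims[1] == 1 ? 0 : m;
                    const std::size_t bn = bias_dims[2] == 1 ? 0 : n;
                    acc += bias[(bmb * bias_dims[1] + bm) * bias_dims[2] + bn];
                }
                dst[(mb * M + m) * N + n] = acc;
            }
    return dst;
}
```

With src = [[1, 2], [3, 4]], weights = [[5, 6], [7, 8]] and a 1 × 1 × 2 bias {10, 20}, the result is [[29, 42], [53, 70]]: the single bias row is reused for every m.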
Execution Arguments
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
Primitive input/output | Execution argument index |
---|---|
\(\src\) | DNNL_ARG_SRC |
\(\weights\) | DNNL_ARG_WEIGHTS |
\(\bias\) | DNNL_ARG_BIAS |
\(\dst\) | DNNL_ARG_DST |
Operation Details
The MatMul primitive supports input and output tensors with run-time specified shapes and memory formats. Run-time dimensions or strides are indicated by the DNNL_RUNTIME_DIM_VAL wildcard value during the primitive initialization and creation stage. At the execution stage, the user must pass fully specified memory objects so that the primitive can perform the computations. Note that the less information about shapes or formats is available at the creation stage, the less performant the execution will be. In particular, if the shape is not known at the creation stage, the special format tag any cannot be used to let an implementation choose the most appropriate memory format for the corresponding input or output shapes. On the other hand, run-time specified shapes enable users to create a primitive once and use it with different shapes.
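A library-independent sketch of the create-once, execute-many contract follows. The sentinel `kRuntimeDim` and the helper `resolve_dims` are hypothetical: the sentinel only mimics the role of DNNL_RUNTIME_DIM_VAL and does not use its actual value. Creation-time dims may contain the wildcard; execution-time dims must be fully specified and must agree with every dimension fixed at creation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical wildcard for "dimension known only at execution time".
// It plays the role of DNNL_RUNTIME_DIM_VAL; this value is illustrative only.
constexpr std::int64_t kRuntimeDim = -1;

// Hypothetical check of execution-time dims against creation-time dims.
// Returns the concrete dims, or std::nullopt if the execution-time shape
// still contains a wildcard or contradicts a dimension fixed at creation.
std::optional<std::vector<std::int64_t>> resolve_dims(
        const std::vector<std::int64_t> &created,
        const std::vector<std::int64_t> &at_execution) {
    // The rank (2D or 3D) is fixed at creation; only sizes may be deferred.
    assert(created.size() == at_execution.size());
    for (std::size_t i = 0; i < created.size(); ++i) {
        if (at_execution[i] == kRuntimeDim) return std::nullopt;
        if (created[i] != kRuntimeDim && created[i] != at_execution[i])
            return std::nullopt;
    }
    return at_execution;
}
```

A shape created as {wildcard, 64} then accepts M = 128 or any other M at execution while K stays fixed; this is the flexibility described above, paid for by the implementation knowing nothing about M in advance.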
Data Type Support
The MatMul primitive supports the following combinations of data types for source, destination, weights, and bias tensors.
Note
Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.
Source | Weights | Destination | Bias |
---|---|---|---|
Data Representation
The MatMul primitive expects the following tensors:
Dims | Source | Weights | Destination | Bias (optional) |
---|---|---|---|---|
2D | \(M \times K\) | \(K \times N\) | \(M \times N\) | \((M \text{ or } 1) \times (N \text{ or } 1)\) |
3D | \(MB \times M \times K\) | \(MB \times K \times N\) | \(MB \times M \times N\) | \((MB \text{ or } 1) \times (M \text{ or } 1) \times (N \text{ or } 1)\) |
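The shape rules in this table can be written as a single check. The following C++ sketch uses a hypothetical `shapes_are_valid` function (not part of oneDNN) that takes dims in the order listed in the table and treats an empty bias as "no bias".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Dims = std::vector<std::int64_t>;

// Hypothetical validator for the table above: src, weights and dst are
// M x K, K x N, M x N (2D) or MB x M x K, MB x K x N, MB x M x N (3D).
// Each bias dimension must match dst or be 1. Pass an empty bias if unused.
bool shapes_are_valid(const Dims &src, const Dims &wei, const Dims &dst,
        const Dims &bias) {
    const std::size_t nd = src.size();
    if (nd != 2 && nd != 3) return false;
    if (wei.size() != nd || dst.size() != nd) return false;
    if (nd == 3 && !(src[0] == wei[0] && src[0] == dst[0])) return false;
    const std::int64_t M = src[nd - 2], K = src[nd - 1];
    const std::int64_t N = wei[nd - 1];
    if (wei[nd - 2] != K || dst[nd - 2] != M || dst[nd - 1] != N) return false;
    if (bias.empty()) return true;
    if (bias.size() != nd) return false; // bias rank must equal dst rank
    for (std::size_t i = 0; i < nd; ++i)
        if (bias[i] != dst[i] && bias[i] != 1) return false;
    return true;
}
```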
The MatMul primitive is generally optimized for the case in which memory objects use plain memory formats (with some restrictions; see the table below). However, if an input tensor is reused across multiple executions, it is recommended to use the placeholder memory format tag any. In this case, the primitive selects the most appropriate memory format for the corresponding input tensor, and the one-time cost of converting the tensor into that format is amortized over the executions.
The table below shows the combinations of memory formats for which the MatMul primitive is optimized. The memory format of the destination tensor should always be ab for the 2D case and abc for the 3D one.
Dims | Logical tensors | MatMul is optimized for the following memory formats |
---|---|---|
2D | Source: \(M \times K\), Weights: \(K \times N\) | |
3D | Source: \(MB \times M \times K\), Weights: \(MB \times K \times N\) | |
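In the plain format tags, the letters name the dimensions from outermost to innermost, so ab stores an M × N matrix row by row and abc places the batch dimension outermost. Assuming dense tensors without padding, element offsets work out as in this sketch (the helper names are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Offset of element (m, n) in a dense M x N tensor stored as ab:
// b (the N dimension) is innermost, so consecutive n are adjacent.
// Equivalent strides: {N, 1}.
std::int64_t offset_ab(std::int64_t m, std::int64_t n, std::int64_t N) {
    return m * N + n;
}

// Offset of element (mb, m, n) in a dense MB x M x N tensor stored as abc.
// Equivalent strides: {M * N, N, 1}.
std::int64_t offset_abc(std::int64_t mb, std::int64_t m, std::int64_t n,
        std::int64_t M, std::int64_t N) {
    return (mb * M + m) * N + n;
}
```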
Attributes and Post-ops
Attributes and post-ops enable modifying the behavior of the MatMul primitive. The following attributes and post-ops are supported:
Type | Operation | Description | Restrictions |
---|---|---|---|
Attribute | Output scales | Scales the result by given scale factor(s) | |
Attribute | Zero points | Sets zero point(s) for the corresponding tensors | Int8 computations only |
Post-op | Eltwise | Applies an elementwise operation to the result | |
Post-op | Sum | Adds the operation result to the destination tensor instead of overwriting it | |
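The two post-ops differ in what they read: the elementwise post-op transforms the value computed so far, whereas the sum post-op also reads the value already stored in the destination. Post-ops are applied in the order in which they are appended, so order matters. The C++ below is a library-independent sketch for a single destination element; `PostOp`, `PostOpKind`, and `apply_post_ops` are hypothetical, and ReLU stands in for an arbitrary elementwise function.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical post-op descriptions (not oneDNN types).
enum class PostOpKind { EltwiseRelu, Sum };

struct PostOp {
    PostOpKind kind;
    float scale = 1.0f; // used by Sum: value += scale * previous dst value
};

// Applies the chain in order to one element. op_result is the MatMul value
// for this element; old_dst is what the destination held before execution.
float apply_post_ops(float op_result, float old_dst,
        const std::vector<PostOp> &chain) {
    float v = op_result;
    for (const PostOp &p : chain) {
        switch (p.kind) {
            case PostOpKind::EltwiseRelu: v = std::max(v, 0.0f); break;
            case PostOpKind::Sum: v += p.scale * old_dst; break;
        }
    }
    return v;
}
```

For an operation result of -3 and a previous destination value of 1, sum followed by ReLU yields max(-3 + 1, 0) = 0, while ReLU followed by sum yields max(-3, 0) + 1 = 1.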
To facilitate dynamic quantization, the primitive supports run-time output scales. This means that a user can configure attributes with output scales set to the DNNL_RUNTIME_F32_VAL wildcard value instead of the actual scales if the scales are not known at the primitive descriptor creation stage. In this case, the user must provide the scales as an additional input memory object with argument DNNL_ARG_ATTR_OUTPUT_SCALES during the execution stage.
Similarly to run-time output scales, the primitive supports run-time zero points. The wildcard value for zero points is DNNL_RUNTIME_S32_VAL. During the execution stage, the corresponding memory object must be passed in the argument with index set to (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_${MEMORY}). For instance, the source tensor zero points memory argument would be passed with index (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC).
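The values supplied at execution time in this way feed the usual asymmetric quantization arithmetic. The sketch below computes one output element of an int8 MatMul from u8 source and s8 weights: subtract the zero points, accumulate in int32, apply the output scale, and add the destination zero point. It is an illustration under those assumptions, not oneDNN's exact computation: bias, post-ops, and the final rounding and saturation to the destination type are omitted, and `RuntimeQuantParams` and `quantized_dot` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Values that, with the DNNL_RUNTIME_F32_VAL / DNNL_RUNTIME_S32_VAL
// wildcards, would only be supplied at execution time (hypothetical struct).
struct RuntimeQuantParams {
    float output_scale;
    std::int32_t src_zero_point;
    std::int32_t wei_zero_point;
    std::int32_t dst_zero_point;
};

// One output element: row is src(m, :), col is weights(:, n).
// Returns the scaled value before rounding/saturation to the dst type.
float quantized_dot(const std::vector<std::uint8_t> &row,
        const std::vector<std::int8_t> &col, const RuntimeQuantParams &q) {
    assert(row.size() == col.size());
    std::int32_t acc = 0; // int32 accumulation of zero-point-adjusted products
    for (std::size_t k = 0; k < row.size(); ++k)
        acc += (static_cast<std::int32_t>(row[k]) - q.src_zero_point)
                * (static_cast<std::int32_t>(col[k]) - q.wei_zero_point);
    return q.output_scale * static_cast<float>(acc)
            + static_cast<float>(q.dst_zero_point);
}
```

Because the scale and zero points arrive as an ordinary argument, the same code handles a new scale without any change, much as a primitive created with the wildcard values can be executed with different scales.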
API

struct dnnl::matmul : public dnnl::primitive
Matrix multiplication (matmul) primitive.
Public Functions

matmul()
Default constructor. Produces an empty object.

matmul(const primitive_desc &pd)
Constructs a matmul primitive.
Parameters:
pd – Primitive descriptor for a matmul primitive.

struct desc
Descriptor for a matmul primitive.
Public Functions

desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive without bias.
Parameters:
src_desc – Memory descriptor for source (matrix A).
weights_desc – Memory descriptor for weights (matrix B).
dst_desc – Memory descriptor for destination (matrix C).

desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive with bias.
Parameters:
src_desc – Memory descriptor for source (matrix A).
weights_desc – Memory descriptor for weights (matrix B).
bias_desc – Memory descriptor for bias.
dst_desc – Memory descriptor for destination (matrix C).

struct primitive_desc : public dnnl::primitive_desc
Primitive descriptor for a matmul primitive.
Public Functions

primitive_desc()
Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)
Constructs a primitive descriptor for a matmul primitive.
Parameters:
adesc – Descriptor for a matmul primitive.
aengine – Engine to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty = false)
Constructs a primitive descriptor for a matmul primitive.
Parameters:
adesc – Descriptor for a matmul primitive.
attr – Primitive attributes to use.
aengine – Engine to use.
allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const
Returns a source memory descriptor.
Returns:
Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc weights_desc() const
Returns a weights memory descriptor.
Returns:
Weights memory descriptor, or a zero memory descriptor if the primitive does not have a weights parameter.