Matrix Multiplication

The matrix multiplication (MatMul) primitive computes the product of two 2D tensors with optional bias addition. Variable names follow the standard Conventions.

\[\dst(m, n) = \sum_{k=0}^{K - 1} \left( \src(m, k) \cdot \weights(k, n) \right) + \bias(m, n)\]

The MatMul primitive also supports batching multiple independent matrix multiplication operations, in which case the tensors must be 3D:

\[\dst(mb, m, n) = \sum_{k=0}^{K - 1} \left( \src(mb, m, k) \cdot \weights(mb, k, n) \right) + \bias(mb, m, n)\]

The bias tensor is optional and supports implicit broadcast semantics: any of its dimensions can be 1, in which case the same value is used across the corresponding dimension. However, \(\bias\) must have the same number of dimensions as \(\dst\).
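The formulas and the bias broadcast rule above can be sketched as a naive reference loop. This is only an illustration in plain C++ and does not use the library API: the `Tensor3` struct and the `matmul_ref` function are hypothetical names introduced here, tensors are assumed to be dense and row-major, and the 2D case is treated as a 3D case with a batch size of 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense row-major 3D tensor; a 2D matrix is the case d0 == 1.
struct Tensor3 {
    std::size_t d0, d1, d2;
    std::vector<float> data; // expected size: d0 * d1 * d2

    float at(std::size_t i, std::size_t j, std::size_t k) const {
        return data[(i * d1 + j) * d2 + k];
    }
};

// Naive reference for dst(mb, m, n) = sum_k src(mb, m, k) * weights(mb, k, n)
// + bias(mb, m, n). Pass nullptr as bias to skip the bias term.
Tensor3 matmul_ref(const Tensor3 &src, const Tensor3 &weights,
        const Tensor3 *bias = nullptr) {
    const std::size_t MB = src.d0, M = src.d1, K = src.d2, N = weights.d2;
    assert(weights.d0 == MB && weights.d1 == K);
    if (bias) {
        // Each bias dimension must match dst or be 1 (broadcast).
        assert(bias->d0 == MB || bias->d0 == 1);
        assert(bias->d1 == M || bias->d1 == 1);
        assert(bias->d2 == N || bias->d2 == 1);
    }

    Tensor3 dst {MB, M, N, std::vector<float>(MB * M * N, 0.0f)};
    for (std::size_t mb = 0; mb < MB; ++mb)
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < K; ++k)
                    acc += src.at(mb, m, k) * weights.at(mb, k, n);
                if (bias) {
                    // A dimension of size 1 always reads index 0.
                    acc += bias->at(bias->d0 == 1 ? 0 : mb,
                            bias->d1 == 1 ? 0 : m, bias->d2 == 1 ? 0 : n);
                }
                dst.data[(mb * M + m) * N + n] = acc;
            }
    return dst;
}
```

A bias of shape \(1 \times 1 \times N\) passed to this sketch adds the same row vector to every row of every batch, which is the broadcast behavior described above.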

Execution Arguments

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

| Primitive input/output | Execution argument index |
| --- | --- |
| \(\src\) | DNNL_ARG_SRC |
| \(\weights\) | DNNL_ARG_WEIGHTS |
| \(\bias\) | DNNL_ARG_BIAS |
| \(\dst\) | DNNL_ARG_DST |

Operation Details

The MatMul primitive supports input and output tensors with run-time specified shapes and memory formats. Run-time dimensions or strides are indicated with the DNNL_RUNTIME_DIM_VAL wildcard value during the primitive initialization and creation stage. At the execution stage, the user must pass fully specified memory objects so that the primitive is able to perform the computations. Note that the less information about shapes or formats is available at the creation stage, the less performant execution will be. In particular, if the shape is not known at the creation stage, the special format tag any cannot be used to let an implementation choose the most appropriate memory format for the corresponding input or output shapes. On the other hand, run-time specified shapes enable users to create a primitive once and use it in different situations.
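The relationship between creation-time wildcards and fully specified execution-time shapes can be modeled with a small sketch. This is a conceptual illustration only, not the library's validation logic: `kRuntimeDim` is a stand-in for DNNL_RUNTIME_DIM_VAL with an arbitrarily chosen value, and `is_fully_specified` and `is_compatible` are hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the DNNL_RUNTIME_DIM_VAL wildcard. The value -1 is an
// assumption made for this sketch; the real value is defined by the library.
constexpr std::int64_t kRuntimeDim = -1;

using Dims = std::vector<std::int64_t>;

// True if no dimension is left as a wildcard.
bool is_fully_specified(const Dims &dims) {
    for (std::int64_t d : dims)
        if (d == kRuntimeDim) return false;
    return true;
}

// True if the execution-time dims are fully specified and agree with every
// dimension that was already fixed at creation time.
bool is_compatible(const Dims &creation, const Dims &execution) {
    if (creation.size() != execution.size()) return false;
    if (!is_fully_specified(execution)) return false;
    for (std::size_t i = 0; i < creation.size(); ++i)
        if (creation[i] != kRuntimeDim && creation[i] != execution[i])
            return false;
    return true;
}
```

Under this model, a source created as `{kRuntimeDim, 64}` can later be executed with `{128, 64}` or `{7, 64}`, but not with `{128, 32}`.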

Data Types Support

The MatMul primitive supports the following combinations of data types for source, destination, weights, and bias tensors.

Note

Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.

| Source | Weights | Destination | Bias |
| --- | --- | --- | --- |
| f32 | f32 | f32 | f32 |
| f16 | f16 | f16 | f16 |
| bf16 | bf16 | bf16 | bf16, f32 |
| u8, s8 | s8, u8 | u8, s8, s32, f32 | u8, s8, s32, f32 |

Data Representation

The MatMul primitive expects the following tensors:

| Dims | Source | Weights | Destination | Bias (optional) |
| --- | --- | --- | --- | --- |
| 2D | \(M \times K\) | \(K \times N\) | \(M \times N\) | \((M \text{ or } 1)\) \(\times (N \text{ or } 1)\) |
| 3D | \(MB \times M \times K\) | \(MB \times K \times N\) | \(MB \times M \times N\) | \((MB \text{ or } 1)\) \(\times (M \text{ or } 1)\) \(\times (N \text{ or } 1)\) |
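The shape rules in the table above can be expressed as a short consistency check. The following is a sketch in plain C++ rather than library code: `infer_dst_dims` and `is_valid_bias` are hypothetical helpers that derive the destination shape from the source and weights shapes and test whether a bias shape is an allowed broadcast of it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Dims = std::vector<std::int64_t>;

// Derives the destination dims (M x N or MB x M x N) from source and weights
// dims, throwing if the shapes do not follow the table above.
Dims infer_dst_dims(const Dims &src, const Dims &weights) {
    const std::size_t nd = src.size();
    if ((nd != 2 && nd != 3) || weights.size() != nd)
        throw std::invalid_argument("src and weights must both be 2D or 3D");
    if (nd == 3 && src[0] != weights[0])
        throw std::invalid_argument("batch dimension MB does not match");
    if (src[nd - 1] != weights[nd - 2])
        throw std::invalid_argument("reduction dimension K does not match");

    Dims dst(src);
    dst[nd - 1] = weights[nd - 1]; // replace K with N
    return dst;
}

// A bias shape is valid if it has the same rank as dst and every dimension
// either equals the dst dimension or is 1 (broadcast).
bool is_valid_bias(const Dims &bias, const Dims &dst) {
    if (bias.size() != dst.size()) return false;
    for (std::size_t i = 0; i < dst.size(); ++i)
        if (bias[i] != dst[i] && bias[i] != 1) return false;
    return true;
}
```

For example, a \(4 \times 3\) source and a \(3 \times 5\) weights tensor give a \(4 \times 5\) destination, for which a \(1 \times 5\) bias is accepted but a one-dimensional bias of length 5 is not.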

The MatMul primitive is generally optimized for the case in which memory objects use plain memory formats (with some restrictions; see the table below). However, it is recommended to use the placeholder memory format any if an input tensor is reused across multiple executions. In this case, the primitive will set the most appropriate memory format for the corresponding input tensor.

The table below shows the combinations of memory formats for which the MatMul primitive is optimized. The memory format of the destination tensor should always be ab for the 2D case and abc for the 3D one.

| Dims | Logical tensors | MatMul is optimized for the following memory formats |
| --- | --- | --- |
| 2D | Source: \(M \times K\), Weights: \(K \times N\) | Source: ab or ba, Weights: ab or ba |
| 3D | Source: \(MB \times M \times K\), Weights: \(MB \times K \times N\) | Source: abc or acb, Weights: abc or acb |
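To make the difference between the plain 2D formats concrete, the sketch below computes element offsets for a dense matrix stored as ab (the last logical dimension is innermost, that is, row-major) and as ba (the first logical dimension is innermost, that is, the matrix is stored transposed). The `Strides2` struct and the helper functions are hypothetical names used only for illustration, and the sketch assumes dense tensors without padding.

```cpp
#include <cassert>
#include <cstdint>

// Strides, in elements, for the two logical dimensions a and b.
struct Strides2 {
    std::int64_t a, b;
};

// "ab": dimension b is contiguous in memory (row-major).
Strides2 strides_ab(std::int64_t /*rows*/, std::int64_t cols) {
    return {cols, 1};
}

// "ba": dimension a is contiguous in memory (column-major).
Strides2 strides_ba(std::int64_t rows, std::int64_t /*cols*/) {
    return {1, rows};
}

// Linear offset of logical element (i, j) for the given strides.
std::int64_t element_offset(const Strides2 &s, std::int64_t i, std::int64_t j) {
    return i * s.a + j * s.b;
}
```

The logical shape is the same in both cases; only the mapping from logical indices to memory changes. The 3D formats abc and acb relate to each other in the same way for the two innermost logical dimensions.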

Attributes and Post-ops

Attributes and post-ops enable modifying the behavior of the MatMul primitive. The following attributes and post-ops are supported:

| Type | Operation | Description | Restrictions |
| --- | --- | --- | --- |
| Attribute | Scales | Sets scale(s) for the corresponding tensor(s) | |
| Attribute | Zero points | Sets zero point(s) for the corresponding tensors | Int8 computations only |
| Post-op | Eltwise | Applies an elementwise operation to the result | |
| Post-op | Binary | Applies a binary operation to the result | |
| Post-op | Sum | Adds the operation result to the destination tensor instead of overwriting it | |
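The effect of the Sum and Eltwise post-ops listed above can be illustrated with a minimal sketch. It models only the arithmetic in plain C++ and is not the library API: `apply_sum_then_relu` is a hypothetical helper, ReLU is chosen here as an example elementwise operation, and the Sum post-op is assumed to apply no additional scaling.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Applies a Sum post-op followed by a ReLU Eltwise post-op.
// `result` holds the freshly computed MatMul values. `dst` holds the previous
// destination contents on entry and the final values on return.
void apply_sum_then_relu(
        const std::vector<float> &result, std::vector<float> &dst) {
    assert(result.size() == dst.size());
    for (std::size_t i = 0; i < dst.size(); ++i) {
        // Sum: accumulate into the existing destination value
        // instead of overwriting it.
        const float summed = result[i] + dst[i];
        // Eltwise (ReLU as an example): clamp negative values to zero.
        dst[i] = std::max(summed, 0.0f);
    }
}
```

Without the Sum post-op, the previous contents of `dst` would simply be replaced by the MatMul result before the elementwise operation is applied.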

The primitive supports dynamic quantization via run-time scales. That is, a user can configure the scales and zero-point attributes at the primitive descriptor creation stage and must then provide the scales and zero points as additional input memory objects with the arguments DNNL_ARG_ATTR_SCALES and DNNL_ARG_ATTR_ZERO_POINTS during the execution stage (more details are provided in the Quantization section).

API

struct dnnl::matmul : public dnnl::primitive

Matrix multiplication (matmul) primitive.

Public Functions

matmul()

Default constructor. Produces an empty object.

matmul(const primitive_desc &pd)

Constructs a matmul primitive.

Parameters

pd – Primitive descriptor for a matmul primitive.

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for a matmul primitive.

Public Functions

primitive_desc() = default

Default constructor. Produces an empty object.

primitive_desc(const engine &aengine, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const primitive_attr &attr = default_attr(), bool allow_empty = false)

Constructs a primitive descriptor for a matmul primitive without bias.

Parameters
  • aengine – Engine to use.

  • src_desc – Memory descriptor for source (matrix A).

  • weights_desc – Memory descriptor for weights (matrix B).

  • dst_desc – Memory descriptor for destination (matrix C).

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const engine &aengine, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const primitive_attr &attr = default_attr(), bool allow_empty = false)

Constructs a primitive descriptor for a matmul primitive with bias.

Parameters
  • aengine – Engine to use.

  • src_desc – Memory descriptor for source (matrix A).

  • weights_desc – Memory descriptor for weights (matrix B).

  • bias_desc – Memory descriptor for bias.

  • dst_desc – Memory descriptor for destination (matrix C).

  • attr – Primitive attributes to use. Attributes are optional and default to empty attributes.

  • allow_empty – A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Returns

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc weights_desc() const

Returns a weights memory descriptor.

Returns

Weights memory descriptor, or a zero memory descriptor if the primitive does not have a weights parameter.

memory::desc bias_desc() const

Returns the bias memory descriptor.

Returns

The bias memory descriptor, or a zero memory descriptor if the primitive does not have a bias parameter.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Returns

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.