Concat

A primitive to concatenate data along an arbitrary dimension.

The concat primitive concatenates \(N\) tensors over concat_dimension (here denoted as \(C\)), and is defined as

\[\dst(\overline{ou}, c, \overline{in}) = \src_i(\overline{ou}, c', \overline{in}),\]

where

  • \(c = C_1 + \ldots + C_{i-1} + c'\), with \(C_j\) the size of the \(j\)-th source along the concat axis and \(0 \le c' < C_i\),

  • \(\overline{ou}\) are the outermost indices (to the left of the concat axis), and

  • \(\overline{in}\) are the innermost indices (to the right of the concat axis).

Variable names follow the standard Conventions.
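The definition is an index mapping. The plain C++ sketch below implements it directly. It is not part of the oneDNN API and makes some assumptions: every tensor is dense, row-major, and already flattened to a logical \([\overline{ou}, C_i, \overline{in}]\) view. The `View3` type and `concat_reference` function are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense tensor viewed as [outer, axis, inner], row-major.
struct View3 {
    std::size_t outer;        // product of the dims left of the concat axis (ou)
    std::size_t axis;         // size along the concat axis (C_i)
    std::size_t inner;        // product of the dims right of the concat axis (in)
    std::vector<float> data;  // outer * axis * inner elements
};

// Reference concat: dst(ou, c, in) = src_i(ou, c', in), where c = C_1 + ... + C_{i-1} + c'.
// Assumes all sources share the same outer and inner extents.
View3 concat_reference(const std::vector<View3> &srcs) {
    assert(!srcs.empty());
    const std::size_t outer = srcs[0].outer, inner = srcs[0].inner;
    std::size_t total = 0;  // C = sum_i C_i
    for (const auto &s : srcs) {
        assert(s.outer == outer && s.inner == inner);
        total += s.axis;
    }
    View3 dst{outer, total, inner, std::vector<float>(outer * total * inner)};
    for (std::size_t ou = 0; ou < outer; ++ou) {
        std::size_t offset = 0;  // C_1 + ... + C_{i-1} for the current source
        for (const auto &s : srcs) {
            for (std::size_t cp = 0; cp < s.axis; ++cp) {
                const std::size_t c = offset + cp;  // destination index along the axis
                for (std::size_t in = 0; in < inner; ++in)
                    dst.data[(ou * total + c) * inner + in] =
                        s.data[(ou * s.axis + cp) * inner + in];
            }
            offset += s.axis;
        }
    }
    return dst;
}
```

Take two sources with \(C_1 = 1\) and \(C_2 = 2\). The result has \(C = 3\), and destination index \(c = 2\) is read from \(c' = 1\) of the second source.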

Forward and Backward

The concat primitive has no notion of forward or backward propagation. The backward propagation of concatenation is simply an identity operation: the gradient of each source is the corresponding slice of the destination gradient along the concat axis.
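The following plain C++ sketch shows why the backward pass is an identity. It is a hypothetical helper, not a oneDNN primitive. It assumes the destination gradient is dense and row-major with a logical [outer, C, inner] view. It cuts that gradient back into per-source gradients along the concat axis, and each value is copied, never combined.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Splits diff_dst, viewed as [outer, C, inner], into per-source gradients.
// The gradient for source i is the slice [C_1+...+C_{i-1}, C_1+...+C_i) along the axis.
std::vector<std::vector<float>> concat_backward(const std::vector<float> &diff_dst,
        std::size_t outer, std::size_t inner, const std::vector<std::size_t> &axis_sizes) {
    std::size_t total = 0;  // C = sum_i C_i
    for (std::size_t ci : axis_sizes) total += ci;
    assert(diff_dst.size() == outer * total * inner);

    std::vector<std::vector<float>> diff_srcs;
    std::size_t offset = 0;  // start of the current source's slice along the axis
    for (std::size_t ci : axis_sizes) {
        std::vector<float> part(outer * ci * inner);
        for (std::size_t ou = 0; ou < outer; ++ou)
            for (std::size_t cp = 0; cp < ci; ++cp)
                for (std::size_t in = 0; in < inner; ++in)
                    // Pure copy: the same index mapping as the forward definition, reversed.
                    part[(ou * ci + cp) * inner + in] =
                        diff_dst[(ou * total + offset + cp) * inner + in];
        diff_srcs.push_back(std::move(part));
        offset += ci;
    }
    return diff_srcs;
}
```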

Execution Arguments

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

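Primitive input/output    Execution argument index
\(\src_i\)                DNNL_ARG_MULTIPLE_SRC + i
\(\dst\)                  DNNL_ARG_DST

Here \(i\) is the zero-based position of the source in the srcs vector passed to the primitive descriptor.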

Operation Details

  1. The \(\dst\) memory format can either be specified by the user or derived by the primitive. The recommended approach is to let the primitive choose the most appropriate format.

  2. The concat primitive requires all source and destination tensors to have the same number of dimensions and identical sizes in every dimension except concat_dimension. Along concat_dimension, the destination size must equal the sum of the source sizes (i.e., \(C = \sum_i C_i\)). Implicit broadcasting is not supported.
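The plain C++ sketch below expresses the shape rule. It also computes the destination dimensions that result when the primitive derives the destination itself, although it does not choose a memory format. `dims_t` is a stand-in for `memory::dims`, and `derive_concat_dst_dims` is a hypothetical helper. The library performs its own validation when the primitive descriptor is created.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using dims_t = std::vector<std::int64_t>;  // stand-in for memory::dims (assumption)

// Returns the destination dims implied by the shape rule, or std::nullopt if the
// sources violate it (different ranks, mismatched non-concat dims, or a bad axis).
std::optional<dims_t> derive_concat_dst_dims(const std::vector<dims_t> &srcs, int concat_dimension) {
    if (srcs.empty()) return std::nullopt;
    const dims_t &first = srcs[0];
    const int ndims = static_cast<int>(first.size());
    if (concat_dimension < 0 || concat_dimension >= ndims) return std::nullopt;

    dims_t dst = first;
    dst[concat_dimension] = 0;
    for (const dims_t &s : srcs) {
        if (s.size() != first.size()) return std::nullopt;  // same rank required
        for (int d = 0; d < ndims; ++d)
            if (d != concat_dimension && s[d] != first[d])
                return std::nullopt;  // no implicit broadcasting
        dst[concat_dimension] += s[concat_dimension];  // C = sum_i C_i
    }
    return dst;
}
```

For example, sources of shape (2, 3, 4) and (2, 5, 4) concatenated over dimension 1 give a destination of shape (2, 8, 4). A source of shape (2, 5, 1) is rejected because its last dimension differs.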

Data Types Support

The concat primitive supports arbitrary data types for source and destination tensors. However, all source tensors must have the same data type, which need not match the data type of the destination tensor.

Data Representation

The concat primitive does not assign any special meaning to the logical dimensions.

Post-ops and Attributes

The concat primitive does not support any post-ops or attributes.

API

struct dnnl::concat : public dnnl::primitive

Tensor concatenation (concat) primitive.

Public Functions

concat()

Default constructor. Produces an empty object.

concat(const primitive_desc &pd)

Constructs a concatenation primitive.

Parameters

pd – Primitive descriptor for a concatenation primitive.

struct primitive_desc : public dnnl::primitive_desc_base

Primitive descriptor for a concat primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const memory::desc &dst, int concat_dimension, const std::vector<memory::desc> &srcs, const engine &aengine, const primitive_attr &attr = primitive_attr())

Constructs a primitive descriptor for an out-of-place concatenation primitive.

Parameters
  • dst – Destination memory descriptor.

  • concat_dimension – Index of the dimension along which the source tensors are concatenated. Note that the order of dimensions does not depend on the memory format.

  • srcs – Vector of source memory descriptors.

  • aengine – Engine to perform the operation on.

  • attr – Primitive attributes to use (optional).
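Because the order of dimensions does not depend on the memory format, concat_dimension always names a logical dimension. The plain C++ sketch below illustrates this. Its `Tensor4`, `nchw_strides`, `nhwc_strides`, `make_tensor`, and `concat_logical` are hypothetical helpers, not oneDNN types, and the two stride layouts are only analogous to the nchw and nhwc format tags. The sketch concatenates an nchw-like source and an nhwc-like source along logical dimension 1 (C) into an nhwc-like destination. Only the strides differ between tensors; the axis index stays the same.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using idx4 = std::array<std::size_t, 4>;

// Hypothetical 4D tensor. Logical dims are always (N, C, H, W), and `strides` is indexed
// by logical dimension, so the physical layout never changes which index names which dim.
struct Tensor4 {
    idx4 dims;
    idx4 strides;
    std::vector<float> data;

    std::size_t offset(const idx4 &i) const {
        return i[0] * strides[0] + i[1] * strides[1] + i[2] * strides[2] + i[3] * strides[3];
    }
};

// Dense strides analogous to the nchw and nhwc layouts.
idx4 nchw_strides(const idx4 &d) { return {d[1] * d[2] * d[3], d[2] * d[3], d[3], 1}; }
idx4 nhwc_strides(const idx4 &d) { return {d[1] * d[2] * d[3], 1, d[3] * d[1], d[1]}; }

Tensor4 make_tensor(const idx4 &d, bool nhwc) {
    const std::size_t n = d[0] * d[1] * d[2] * d[3];
    return {d, nhwc ? nhwc_strides(d) : nchw_strides(d), std::vector<float>(n, 0.0f)};
}

// Concatenates along *logical* dimension `axis`; each tensor may use a different layout.
// Assumes dst.dims already satisfy the shape rule for this axis.
void concat_logical(const std::vector<const Tensor4 *> &srcs, int axis, Tensor4 &dst) {
    assert(axis >= 0 && axis < 4);
    std::size_t base = 0;  // C_1 + ... + C_{i-1} along `axis`
    for (const Tensor4 *s : srcs) {
        idx4 i{};
        for (i[0] = 0; i[0] < s->dims[0]; ++i[0])
            for (i[1] = 0; i[1] < s->dims[1]; ++i[1])
                for (i[2] = 0; i[2] < s->dims[2]; ++i[2])
                    for (i[3] = 0; i[3] < s->dims[3]; ++i[3]) {
                        idx4 j = i;
                        j[axis] += base;  // shift only the logical concat index
                        dst.data[dst.offset(j)] = s->data[s->offset(i)];
                    }
        base += s->dims[axis];
    }
}
```

For instance, take an nchw-like source of logical shape (1, 2, 1, 2) and an nhwc-like source of shape (1, 1, 1, 2). Passing axis 1 produces a (1, 3, 1, 2) result, regardless of how either tensor is laid out in memory.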

primitive_desc(int concat_dimension, const std::vector<memory::desc> &srcs, const engine &aengine, const primitive_attr &attr = primitive_attr())

Constructs a primitive descriptor for an out-of-place concatenation primitive.

This version derives the destination memory descriptor automatically.

Parameters
  • concat_dimension – Index of the dimension along which the source tensors are concatenated. Note that the order of dimensions does not depend on the memory format.

  • srcs – Vector of source memory descriptors.

  • aengine – Engine to perform the operation on.

  • attr – Primitive attributes to use (optional).

memory::desc src_desc(int idx = 0) const

Returns a source memory descriptor.

Parameters

idx – Source index.

Returns

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter with index idx.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Returns

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.