Resampling

The resampling primitive computes the forward or backward resampling operation on 1D, 2D, or 3D spatial data. Resampling performs spatial scaling of the original tensor using one of the supported interpolation algorithms:

  • Nearest Neighbor

  • Linear (or Bilinear for 2D spatial tensor, Trilinear for 3D spatial tensor).

The resampling operation is defined by the source tensor and a scaling factor for each spatial dimension. Upsampling and downsampling are alternative terms for resampling, used when all scaling factors are greater than one (upsampling) or all are less than one (downsampling).

The resampling operation is defined by the following formulas. We show formulas only for 2D spatial data, which generalize straightforwardly to lower and higher dimensions. Variable names follow the standard Conventions.

Let \(\src\) and \(\dst\) be \(N \times C \times IH \times IW\) and \(N \times C \times OH \times OW\) tensors respectively. Let \(F_h = \frac{OH}{IH}\) and \(F_w = \frac{OW}{IW}\) define scaling factors in each spatial dimension.

The following formulas show how oneDNN computes resampling for nearest neighbor and bilinear interpolation methods. To further simplify the formulas, we assume the following:

  • \(\src(n, ic, ih, iw) = 0\) if \(ih < 0\) or \(iw < 0\),

  • \(\src(n, ic, ih, iw) = \src(n, ic, IH - 1, iw)\) if \(ih \geq IH\),

  • \(\src(n, ic, ih, iw) = \src(n, ic, ih, IW - 1)\) if \(iw \geq IW\).
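The three boundary conventions can be captured in a single accessor, so the interpolation formulas never need special cases. The sketch below is a minimal illustration, assuming a dense tensor in plain nchw order; the names Dims4 and src_at are hypothetical and not part of the oneDNN API.

```cpp
#include <cstddef>
#include <vector>

// Dimensions of an N x C x IH x IW tensor stored densely in nchw order.
// Assumption: nchw is used only to keep the indexing simple; the primitive
// itself accepts other memory formats.
struct Dims4 {
    long N, C, H, W;
};

// Reads src(n, c, ih, iw) under the boundary conventions above:
// a negative spatial index reads as zero, and an index past the end
// is clamped to the last row (ih) or the last column (iw).
float src_at(const std::vector<float> &src, const Dims4 &d,
             long n, long c, long ih, long iw) {
    if (ih < 0 || iw < 0) return 0.f;
    if (ih >= d.H) ih = d.H - 1;
    if (iw >= d.W) iw = d.W - 1;
    const long offset = ((n * d.C + c) * d.H + ih) * d.W + iw;
    return src[static_cast<std::size_t>(offset)];
}
```

For a 2 x 2 plane holding 1, 2, 3, 4, reading row 5 of column 0 returns 3 (clamped to the last row), while reading row -1 returns 0.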

Forward

Nearest Neighbor Resampling

\[\dst(n, c, oh, ow) = \src(n, c, ih, iw)\]

where \([\cdot]\) denotes rounding to the nearest integer, and

  • \(ih = [\frac{oh + 0.5} {F_h} - 0.5]\),

  • \(iw = [\frac{ow + 0.5} {F_w} - 0.5]\).
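The nearest neighbor mapping above can be sketched as a plain C++ reference loop over one spatial plane. This is an illustration of the formula, not oneDNN's kernel; the function names are hypothetical, and treating \([\cdot]\) as std::round (halfway cases rounded away from zero) is an assumption, because the tie rule is not stated here.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Source index along one spatial dimension: i = [(o + 0.5) / F - 0.5],
// where F = out_size / in_size. Assumption: [.] is std::round, which
// rounds halfway cases away from zero.
long nearest_index(long o, long in_size, long out_size) {
    const float F = static_cast<float>(out_size) / static_cast<float>(in_size);
    return static_cast<long>(std::round((static_cast<float>(o) + 0.5f) / F - 0.5f));
}

// Nearest neighbor resampling of one IH x IW plane (row-major) to OH x OW.
// The n and c indices pass through the formula unchanged, so one plane
// is enough to show the spatial mapping.
std::vector<float> resample_nearest_2d(const std::vector<float> &src,
                                       long IH, long IW, long OH, long OW) {
    std::vector<float> dst(static_cast<std::size_t>(OH * OW));
    for (long oh = 0; oh < OH; ++oh) {
        for (long ow = 0; ow < OW; ++ow) {
            long ih = nearest_index(oh, IH, OH);
            long iw = nearest_index(ow, IW, OW);
            // Boundary conventions: negative reads as zero, past-the-end is clamped.
            float v = 0.f;
            if (ih >= 0 && iw >= 0) {
                if (ih >= IH) ih = IH - 1;
                if (iw >= IW) iw = IW - 1;
                v = src[static_cast<std::size_t>(ih * IW + iw)];
            }
            dst[static_cast<std::size_t>(oh * OW + ow)] = v;
        }
    }
    return dst;
}
```

For a 2 x 2 input upsampled to 4 x 4 (\(F_h = F_w = 2\)), output columns 0, 1, 2, 3 read source columns 0, 0, 1, 1, so each source pixel is replicated into a 2 x 2 block.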

Bilinear Resampling

\[\begin{split}\dst(n, c, oh, ow) = \src(n, c, ih_0, iw_0) \cdot (1 - W_{ih}) \cdot (1 - W_{iw}) + \\ \src(n, c, ih_1, iw_0) \cdot W_{ih} \cdot (1 - W_{iw}) + \\ \src(n, c, ih_0, iw_1) \cdot (1 - W_{ih}) \cdot W_{iw} + \\ \src(n, c, ih_1, iw_1) \cdot W_{ih} \cdot W_{iw} \\\end{split}\]

where

  • \(ih_0 = \left\lfloor{\frac {oh + 0.5} {F_h} - 0.5}\right\rfloor\),

  • \(ih_1 = \left\lceil {\frac {oh + 0.5} {F_h} - 0.5}\right\rceil\),

  • \(iw_0 = \left\lfloor{\frac {ow + 0.5} {F_w} - 0.5}\right\rfloor\),

  • \(iw_1 = \left\lceil {\frac {ow + 0.5} {F_w} - 0.5}\right\rceil\),

  • \(W_{ih} = \frac{oh + 0.5}{F_h} - 0.5 - ih_0\),

  • \(W_{iw} = \frac{ow + 0.5}{F_w} - 0.5 - iw_0\).
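Since \(W_{ih}\) is the distance of the sampling point from \(ih_0\), standard linear interpolation weights \(ih_0\) by \(1 - W_{ih}\) and \(ih_1\) by \(W_{ih}\), and likewise for the width. The sketch below is a plain C++ reference loop over one plane under that weighting and the boundary conventions stated earlier; it is an illustration of the math with hypothetical names, not oneDNN's implementation.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reads src with the boundary conventions stated earlier: negative indices
// read as zero, indices past the end are clamped to the last row/column.
static float at_bounded(const std::vector<float> &src, long IH, long IW,
                        long ih, long iw) {
    if (ih < 0 || iw < 0) return 0.f;
    if (ih >= IH) ih = IH - 1;
    if (iw >= IW) iw = IW - 1;
    return src[static_cast<std::size_t>(ih * IW + iw)];
}

// Bilinear resampling of one IH x IW plane (row-major) to OH x OW.
std::vector<float> resample_bilinear_2d(const std::vector<float> &src,
                                        long IH, long IW, long OH, long OW) {
    const float Fh = static_cast<float>(OH) / static_cast<float>(IH);
    const float Fw = static_cast<float>(OW) / static_cast<float>(IW);
    std::vector<float> dst(static_cast<std::size_t>(OH * OW));
    for (long oh = 0; oh < OH; ++oh) {
        const float y = (static_cast<float>(oh) + 0.5f) / Fh - 0.5f;
        const long ih0 = static_cast<long>(std::floor(y));
        const long ih1 = static_cast<long>(std::ceil(y));
        const float Wih = y - static_cast<float>(ih0);  // distance from ih0
        for (long ow = 0; ow < OW; ++ow) {
            const float x = (static_cast<float>(ow) + 0.5f) / Fw - 0.5f;
            const long iw0 = static_cast<long>(std::floor(x));
            const long iw1 = static_cast<long>(std::ceil(x));
            const float Wiw = x - static_cast<float>(iw0);  // distance from iw0
            // The nearer neighbor gets the larger weight.
            dst[static_cast<std::size_t>(oh * OW + ow)] =
                at_bounded(src, IH, IW, ih0, iw0) * (1.f - Wih) * (1.f - Wiw) +
                at_bounded(src, IH, IW, ih1, iw0) * Wih * (1.f - Wiw) +
                at_bounded(src, IH, IW, ih0, iw1) * (1.f - Wih) * Wiw +
                at_bounded(src, IH, IW, ih1, iw1) * Wih * Wiw;
        }
    }
    return dst;
}
```

A quick sanity check for the weighting is a ramp: resampling the 1 x 4 row 0, 1, 2, 3 to width 3 (\(F_w = 0.75\)) samples at positions 1/6, 1.5, and 17/6, and linear interpolation of a ramp must return exactly those positions.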

Difference Between Forward Training and Forward Inference

There is no difference between the forward_training and forward_inference propagation kinds.

Backward

The backward propagation computes \(\diffsrc\) based on \(\diffdst\).
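This section does not spell out the backward formulas. For nearest neighbor, the gradient follows directly from the forward definition: each \(\dst\) element is a copy of exactly one \(\src\) element, so each \(\diffsrc\) element is the sum of the \(\diffdst\) elements that were copied from it. The sketch below implements that adjoint for one plane, again assuming std::round for \([\cdot]\); the function names are hypothetical, and the code illustrates the math rather than oneDNN's actual kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Nearest source index for one dimension: i = round((o + 0.5) / F - 0.5),
// clamped into [0, in_size - 1]. std::round is an assumed tie rule.
static long nn_index(long o, long in_size, long out_size) {
    const float F = static_cast<float>(out_size) / static_cast<float>(in_size);
    const long i = static_cast<long>(
        std::round((static_cast<float>(o) + 0.5f) / F - 0.5f));
    return std::min(std::max(i, 0L), in_size - 1);
}

// Backward nearest neighbor resampling for one plane: every diff_dst element
// is added to the diff_src element that the forward pass copied it from.
std::vector<float> resample_nearest_backward_2d(const std::vector<float> &diff_dst,
                                                long IH, long IW, long OH, long OW) {
    std::vector<float> diff_src(static_cast<std::size_t>(IH * IW), 0.f);
    for (long oh = 0; oh < OH; ++oh) {
        const long ih = nn_index(oh, IH, OH);
        for (long ow = 0; ow < OW; ++ow) {
            const long iw = nn_index(ow, IW, OW);
            diff_src[static_cast<std::size_t>(ih * IW + iw)] +=
                diff_dst[static_cast<std::size_t>(oh * OW + ow)];
        }
    }
    return diff_src;
}
```

With 2x upsampling, each source pixel feeds a 2 x 2 block of outputs, so a diff_dst of all ones yields 4 in every diff_src element.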

Execution Arguments

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

Primitive input/output | Execution argument index
-----------------------|-------------------------
\(\src\)               | DNNL_ARG_SRC
\(\dst\)               | DNNL_ARG_DST
\(\diffsrc\)           | DNNL_ARG_DIFF_SRC
\(\diffdst\)           | DNNL_ARG_DIFF_DST

Operation Details

  1. The resampling implementation supports data in any memory format tag (nchw, nhwc, etc.), but the src and dst memory tags are expected to be the same. The resampling primitive accepts the memory tag any for dst and diff_src, in which case it defines the destination format based on the source format.

  2. A resampling descriptor can be created from the source and destination memory descriptors, from only the source memory descriptor and floating-point factors, or from the source and destination memory descriptors together with factors. When the user does not provide the destination descriptor, the destination dimensions are deduced from the factors: \(output\_spatial\_size = \left\lfloor{input\_spatial\_size \cdot F}\right\rfloor\).

Note

The resampling algorithm uses the factors defined by the relation \(F = \frac{output\_spatial\_size} {input\_spatial\_size}\), which are not necessarily equal to the ones passed by the user.
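Because \(F\) is defined as output size over input size, deducing the destination size from a user factor means multiplying and rounding down, after which the factor seen by the formulas is recomputed from the integer sizes. The sketch below shows that arithmetic; DeducedShape and deduce_dst_shape are hypothetical names, and the library may perform the multiplication in a different precision.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Destination sizes deduced from user factors, plus the factors that the
// resampling formulas actually use afterwards.
struct DeducedShape {
    std::vector<long> out_sizes;
    std::vector<float> effective_factors;
};

// out = floor(in * F) per spatial dimension, then F_eff = out / in.
DeducedShape deduce_dst_shape(const std::vector<long> &in_sizes,
                              const std::vector<float> &factors) {
    DeducedShape r;
    for (std::size_t i = 0; i < in_sizes.size(); ++i) {
        const long out = static_cast<long>(
            std::floor(static_cast<float>(in_sizes[i]) * factors[i]));
        r.out_sizes.push_back(out);
        r.effective_factors.push_back(static_cast<float>(out) /
                                      static_cast<float>(in_sizes[i]));
    }
    return r;
}
```

For example, a spatial size of 5 with a requested factor of 1.5 gives floor(7.5) = 7 output elements, and the formulas then use \(F = 7/5 = 1.4\) rather than 1.5.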

Data Types Support

The resampling primitive supports the following combinations of data types for source and destination memory objects.

Note

Here we abbreviate data type names for readability. For example, dnnl::memory::data_type::f32 is abbreviated to f32.

Propagation        | Source / Destination
-------------------|---------------------
forward / backward | f32, bf16
forward            | f16, s8, u8

Post-ops and Attributes

The resampling primitive does not support any post-ops or attributes.

API

struct dnnl::resampling_forward : public dnnl::primitive

Resampling forward propagation primitive.

Public Functions

resampling_forward()

Default constructor. Produces an empty object.

resampling_forward(const primitive_desc &pd)

Constructs a resampling forward propagation primitive.

Parameters
  • pd: Primitive descriptor for a resampling forward propagation primitive.

struct desc

Descriptor for a resampling forward propagation primitive.

Public Functions

desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc)

Constructs a descriptor for a resampling forward propagation primitive using source and destination memory descriptors.

Note

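The destination memory descriptor may be initialized with the dnnl::memory::format_tag::any format tag.

Parameters
  • aprop_kind: Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • aalgorithm: Resampling algorithm kind: either dnnl::algorithm::resampling_nearest or dnnl::algorithm::resampling_linear.

  • src_desc: Source memory descriptor.

  • dst_desc: Destination memory descriptor.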

desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector<float> &factors, const memory::desc &src_desc)

Constructs a descriptor for a resampling forward propagation primitive using source memory descriptor and factors.

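Parameters
  • aprop_kind: Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • aalgorithm: Resampling algorithm kind: either dnnl::algorithm::resampling_nearest or dnnl::algorithm::resampling_linear.

  • factors: Vector of scaling factors, one per spatial dimension.

  • src_desc: Source memory descriptor.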

desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector<float> &factors, const memory::desc &src_desc, const memory::desc &dst_desc)

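Constructs a descriptor for a resampling forward propagation primitive using source and destination memory descriptors and factors.

Note

The destination memory descriptor may be initialized with the dnnl::memory::format_tag::any format tag.

Parameters
  • aprop_kind: Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • aalgorithm: Resampling algorithm kind: either dnnl::algorithm::resampling_nearest or dnnl::algorithm::resampling_linear.

  • factors: Vector of scaling factors, one per spatial dimension.

  • src_desc: Source memory descriptor.

  • dst_desc: Destination memory descriptor.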

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for a resampling forward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for a resampling forward propagation primitive.

Parameters
  • adesc: Descriptor for a resampling forward propagation primitive.

  • aengine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for a resampling forward propagation primitive.

Parameters
  • adesc: Descriptor for a resampling forward propagation primitive.

  • attr: Primitive attributes to use.

  • aengine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Return

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Return

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.

struct dnnl::resampling_backward : public dnnl::primitive

Resampling backward propagation primitive.

Public Functions

resampling_backward()

Default constructor. Produces an empty object.

resampling_backward(const primitive_desc &pd)

Constructs a resampling backward propagation primitive.

Parameters
  • pd: Primitive descriptor for a resampling backward propagation primitive.

struct desc

Descriptor for a resampling backward propagation primitive.

Public Functions

desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)

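Constructs a descriptor for a resampling backward propagation primitive using diff source and diff destination memory descriptors.

Parameters
  • aalgorithm: Resampling algorithm kind: either dnnl::algorithm::resampling_nearest or dnnl::algorithm::resampling_linear.

  • diff_src_desc: Diff source memory descriptor.

  • diff_dst_desc: Diff destination memory descriptor.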

desc(algorithm aalgorithm, const std::vector<float> &factors, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)

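Constructs a descriptor for a resampling backward propagation primitive using factors and diff source and diff destination memory descriptors.

Parameters
  • aalgorithm: Resampling algorithm kind: either dnnl::algorithm::resampling_nearest or dnnl::algorithm::resampling_linear.

  • factors: Vector of scaling factors, one per spatial dimension.

  • diff_src_desc: Diff source memory descriptor.

  • diff_dst_desc: Diff destination memory descriptor.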

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for a resampling backward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for a resampling backward propagation primitive.

Parameters
  • adesc: Descriptor for a resampling backward propagation primitive.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for a resampling forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for a resampling backward propagation primitive.

Parameters
  • adesc: Descriptor for a resampling backward propagation primitive.

  • attr: Primitive attributes to use.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for a resampling forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc diff_src_desc() const

Returns a diff source memory descriptor.

Return

Diff source memory descriptor, or a zero memory descriptor if the primitive does not have a diff source parameter.

memory::desc diff_dst_desc() const

Returns a diff destination memory descriptor.

Return

Diff destination memory descriptor, or a zero memory descriptor if the primitive does not have a diff destination parameter.