Clarification on `relative_shift` Output Dimension and Relative Distance Mapping in the preprint

Hi team,

Thank you for sharing this great work!

I have a question regarding the relative_shift function in your preprint.

The pseudo-code for relative_shift suggests an output dimension of [..., S, 2*S - 1]. However, the relative attention calculation within sequence_to_pair_block seems to treat its input (e.g., rel_q_a) as [..., S, S] (or [B, S, S, H] after batch and head dimensions are considered).

This suggests that relative_shift performs a critical transformation to reduce the 2 * S - 1 relative range down to S while correctly aligning the relative positional information for the (query_position, key_position) pairs.

Could you please elaborate on:

  1. The precise output dimension of relative_shift. Is it indeed [..., S, S], and if so, how is the 2 * S - 1 dimension mapped or trimmed to S within this function?
  2. How a given relative distance (k−q) is mapped to the (query_position, key_position) indices within the [..., S, S] output tensor. This mapping determines how the sequence-to-pair representation incorporates relative positional information.

Hi!

In the revised version of the manuscript we are going to add the pseudocode for the relative_shift operation:

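import numpy as np

def relative_shift(x: np.ndarray) -> np.ndarray:
  """Map x[..., S, 2 * S - 1] to [..., S, S] with out[..., q, k] = x[..., q, k - q + S - 1]."""
  *batch_shapes, seq_length, num_diagonals = x.shape
  # Prepend one zero column so that every row holds 2 * S entries.
  x = np.pad(x, [(0, 0)] * (len(batch_shapes) + 1) + [(1, 0)])
  x = x.reshape(batch_shapes + [num_diagonals + 1, seq_length])
  # Dropping S entries and re-reading with width 2 * S - 1 makes row q start at input column S - 1 - q.
  x = x[..., 1:, :].reshape(batch_shapes + [seq_length, num_diagonals])
  return x[..., :seq_length]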
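
To see why padding and reshaping produces the shift, it can help to look at each intermediate array for a small S. The following is a minimal sketch, not part of the manuscript: `relative_shift_steps` is a hypothetical helper introduced here, it handles only a 2-D input, and it assumes that `Pad` in the pseudocode pads with zeros in the same way as `np.pad`.

```python
import numpy as np

def relative_shift_steps(x):
    """Return every intermediate array of the pad-and-reshape trick (2-D input only)."""
    seq_length, num_diagonals = x.shape
    # Prepend one zero column: [S, 2S - 1] -> [S, 2S].
    padded = np.pad(x, [(0, 0), (1, 0)])
    # Reinterpret the same buffer as [2S, S].
    folded = padded.reshape(num_diagonals + 1, seq_length)
    # Drop the first S entries and read the rest with row width 2S - 1.
    unfolded = folded[1:, :].reshape(seq_length, num_diagonals)
    # Keep the first S columns of every row.
    return padded, folded, unfolded, unfolded[:, :seq_length]

S = 3
# Every row lists the relative distances -(S - 1), ..., S - 1.
x = np.tile(np.arange(1 - S, S), (S, 1))
names = ("padded", "folded", "unfolded", "output")
for name, step in zip(names, relative_shift_steps(x)):
    print(name, step.shape)
    print(step)
```

Because each padded row has 2S entries but is read back with width 2S - 1, every successive row starts one position earlier in the original row, which is what aligns column k of row q with relative distance k - q.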

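For example, relative_shift takes an input whose columns are indexed by relative distance and returns an [S, S] matrix whose entry (q, k) holds the value for distance k - q: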

>>> import numpy as np
>>> S = 4
>>> D = 2 * S - 1
>>> example = np.linspace(1 - S, D - S, D)
>>> diagonals_as_cols = np.stack([example for _ in range(S)])
>>> diagonals_as_cols
array([[-3., -2., -1.,  0.,  1.,  2.,  3.],
       [-3., -2., -1.,  0.,  1.,  2.,  3.],
       [-3., -2., -1.,  0.,  1.,  2.,  3.],
       [-3., -2., -1.,  0.,  1.,  2.,  3.]])
>>> relative_shift(diagonals_as_cols)
array([[ 0.,  1.,  2.,  3.],
       [-1.,  0.,  1.,  2.],
       [-2., -1.,  0.,  1.],
       [-3., -2., -1.,  0.]])
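
Read off from this example, the mapping appears to be out[..., q, k] = x[..., q, k - q + S - 1], so input column j carries relative distance j - (S - 1). The sketch below checks that reading against the pad-and-reshape version on a random batched input. It is an illustration rather than the authors' code: `relative_shift_gather` is a hypothetical name introduced here, and the NumPy spelling of `relative_shift` assumes that `Pad` behaves like zero-padding with `np.pad`.

```python
import numpy as np

def relative_shift(x):
    # NumPy rendering of the pad-and-reshape pseudocode (assumes Pad == np.pad with zeros).
    *batch_shapes, seq_length, num_diagonals = x.shape
    x = np.pad(x, [(0, 0)] * (len(batch_shapes) + 1) + [(1, 0)])
    x = x.reshape(batch_shapes + [num_diagonals + 1, seq_length])
    x = x[..., 1:, :].reshape(batch_shapes + [seq_length, num_diagonals])
    return x[..., :seq_length]

def relative_shift_gather(x):
    # Hypothetical reference version: pick column k - q + S - 1 for every (q, k) pair.
    seq_length = x.shape[-2]
    q = np.arange(seq_length)[:, None]   # query positions, shape [S, 1]
    k = np.arange(seq_length)[None, :]   # key positions, shape [1, S]
    cols = k - q + seq_length - 1        # input column per pair, values in [0, 2S - 2]
    return x[..., q, cols]               # shape [..., S, S]

# Compare both versions on a random input with two leading batch axes.
rng = np.random.default_rng(0)
S = 5
x = rng.normal(size=(2, 3, S, 2 * S - 1))
print(np.array_equal(relative_shift(x), relative_shift_gather(x)))
```

The explicit gather is easier to read, while the pad-and-reshape form avoids building index arrays; both select the same elements of the input.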