Hi team,
Thank you for sharing this great work!
I have a question regarding the relative_shift function in your preprint.
The pseudo-code for relative_shift suggests an output dimension of [..., S, 2*S - 1]. However, the relative attention calculation within sequence_to_pair_block seems to treat its input (e.g., rel_q_a) as [..., S, S] (or [B, S, S, H] after batch and head dimensions are considered).
This suggests that relative_shift performs a critical transformation to reduce the 2 * S - 1 relative range down to S while correctly aligning the relative positional information for the (query_position, key_position) pairs.
Could you please elaborate on:
- The precise output dimension of relative_shift: is it indeed [..., S, S], and if so, how is the 2 * S - 1 dimension mapped or trimmed to S within this function?
- How the specific relative distance (k−q) is precisely mapped to the (query_position, key_position) indices within the [..., S, S] output tensor.

Understanding this mapping is crucial for grasping how the sequence-to-pair representation correctly incorporates relative positional information.
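
To make the question concrete, here is a minimal NumPy sketch of the mapping I would expect. It is my own assumption, not taken from the preprint: the name relative_shift_sketch is hypothetical, and I assume that index r along the last axis of the input encodes the relative distance r - (S - 1), so that distances run from -(S - 1) to S - 1.

```python
import numpy as np

def relative_shift_sketch(x):
    """Hypothetical gather-based shift: [..., S, 2*S - 1] -> [..., S, S].

    Assumes x[..., q, r] holds the value for relative distance r - (S - 1).
    """
    s = x.shape[-2]
    assert x.shape[-1] == 2 * s - 1, "last axis must have length 2*S - 1"
    q = np.arange(s)[:, None]            # query positions, shape [S, 1]
    k = np.arange(s)[None, :]            # key positions, shape [1, S]
    idx = (k - q) + (s - 1)              # shape [S, S], values in [0, 2*S - 2]
    idx = np.broadcast_to(idx, x.shape[:-1] + (s,))
    # out[..., q, k] = x[..., q, (k - q) + (S - 1)]
    return np.take_along_axis(x, idx, axis=-1)

# Usage: fill each row with the distance it encodes, then check the result.
S = 3
x = np.broadcast_to(np.arange(2 * S - 1) - (S - 1), (S, 2 * S - 1))
out = relative_shift_sketch(x)           # out[q, k] equals k - q
```

Under this assumed convention, each query row q keeps only the S entries whose distances correspond to valid key positions. If the actual implementation uses a pad-and-reshape trick or a different offset convention, the index formula would differ, which is what I would like to confirm.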