Possible mistake in the implementation of the randomised relative positional encoding?
yochaiye opened this issue · 1 comment
Hi,
Thank you for the interesting paper and for sharing the implementation.
I would like to use non-causal randomised relative positional encodings in my project.
I was wondering whether there is a mistake in the arguments passed to sinusoid_position_encoding in the compute_attention_with_noisy_relative_encodings function: the sequence_length argument is set to noise_max_length while the add_negative_side argument is set to False. Doesn't this conflict with the fact that the indexes in line 497 are in the range [0, 2*noise_max_length-1]? I assume that add_negative_side should be set to causal (defined as an input to compute_attention_with_noisy_relative_encodings)?
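To make the concern concrete, here is a minimal sketch in plain NumPy (the sizes are hypothetical, and the sampling just mirrors what the function does) of how large the shifted indexes can get compared to the number of rows in the encoding table when add_negative_side=False:

import numpy as np

noise_max_length = 8   # hypothetical value
k_seq_len = 4          # hypothetical sequence length

# Indexes as sampled in compute_attention_with_noisy_relative_encodings
# (mirrored case), then shifted by +noise_max_length.
rng = np.random.default_rng(0)
right = np.sort(rng.choice(np.arange(1, noise_max_length),
                           size=k_seq_len - 1, replace=False))
left = np.concatenate([np.zeros(1), -right[::-1]])
indexes = np.concatenate([left, np.zeros(1), right]) + noise_max_length

# The shifted indexes can reach 2 * noise_max_length - 1, but with
# add_negative_side=False the table returned by sinusoid_position_encoding
# only has noise_max_length rows (positions 0 .. noise_max_length - 1).
print(int(indexes.max()), 'vs table rows:', noise_max_length)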
Yes, we fixed this internally, but not externally yet. The fixed functions are:
def sinusoid_position_encoding(
    sequence_length: int,
    hidden_size: int,
    max_timescale: float = 1e4,
    add_negative_side: bool = False,
    keep_positive_side: bool = True,
) -> np.ndarray:
  """Creates sinusoidal encodings from the original transformer paper.

  The returned values are, for all i < D/2:
    array[pos, i] = sin(pos / (max_timescale^(2*i / D)))
    array[pos, D/2 + i] = cos(pos / (max_timescale^(2*i / D)))

  Args:
    sequence_length: Sequence length.
    hidden_size: Dimension of the positional encoding vectors, D. Should be
      even.
    max_timescale: Maximum timescale for the frequency.
    add_negative_side: Whether to also include the positional encodings for
      negative positions (we need both sides for relative encodings with
      `causal = False`).
    keep_positive_side: Whether to keep the positional encodings for the
      positive positions (we do not want them for relative encodings with
      `causal = True`).

  Returns:
    An array of shape [L, D] if `add_negative_side` or `keep_positive_side` is
    `False`, else [2 * L, D].
  """
  if not add_negative_side and not keep_positive_side:
    raise ValueError(
        'Either `add_negative_side` or `keep_positive_side` must be `True`.'
    )
  freqs = np.arange(0, hidden_size + 1, 2)
  inv_freq = max_timescale ** (-freqs / hidden_size)
  # Negative positions start at -L when the positive side is kept, else at 1 - L.
  start = (
      0 if not add_negative_side else (1 - keep_positive_side) - sequence_length
  )
  stop = sequence_length if keep_positive_side else 1
  pos_seq = np.arange(start=start, stop=stop)
  sinusoid_inp = np.einsum('i,j->ij', pos_seq, inv_freq)
  embeddings = np.concatenate(
      [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1
  )
  return embeddings[:, :hidden_size]
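As a quick sanity check on the shapes (assuming the function above is in scope; the expected shapes follow directly from the code):

# Positive side only: rows for positions 0 .. L-1.
enc = sinusoid_position_encoding(sequence_length=8, hidden_size=16)
print(enc.shape)  # (8, 16)

# Both sides, as needed for non-causal relative encodings:
# rows for positions -L .. L-1, so 2 * L in total.
enc_full = sinusoid_position_encoding(
    sequence_length=8, hidden_size=16, add_negative_side=True)
print(enc_full.shape)  # (16, 16)
# Row 8 (= sequence_length) holds the encoding for relative position 0.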
and the second fixed function:
def compute_attention_with_noisy_relative_encodings(
    queries: chex.Array,
    keys: chex.Array,
    noise_max_length: int,
    randomize_both_sides: bool = False,
    max_time: int = 10_000,
) -> chex.Array:
  """Returns attention with *noisy* relative positional encodings.

  This code follows what is described in the TransformerXL paper.
  https://arxiv.org/pdf/1901.02860.pdf

  However, in this version, the base positional encodings R (which are then
  shifted) are randomly sampled and ordered from a wider range than the
  sequence length.

  Args:
    queries: The queries used for attention. Shape (b, t, h, d).
    keys: The keys used for attention. Shape (b, T, h, d).
    noise_max_length: The maximum length used to sample the encodings.
    randomize_both_sides: Whether to sample the encodings on both the left and
      the right of the current token, or to sample only the right side and use
      the mirrored (negated) indexes for the left part.
    max_time: Constant used to scale position by in the sin/cos encodings.

  Returns:
    The attention logits. Shape (b, h, t, T).
  """
  batch_size, k_seq_len, num_heads, num_hiddens = keys.shape
  hiddens = num_hiddens * num_heads

  # First compute the content logits.
  content_bias = hk.get_parameter(
      name='relpos_contentbias',
      shape=[num_heads, num_hiddens],
      init=hk.initializers.RandomNormal(stddev=0.02),
  )
  content_logits = jnp.einsum('bthd,bThd->bhtT', queries + content_bias, keys)

  # Select random indexes.
  # The indexes are in the range [-noise_max_length + 1, noise_max_length - 1].
  right_indexes = jrandom.choice(
      hk.next_rng_key(),
      jnp.arange(1, noise_max_length),
      shape=(k_seq_len - 1,),
      replace=False,
  )
  right_indexes = jnp.sort(right_indexes)
  if randomize_both_sides:
    left_indexes = jrandom.choice(
        hk.next_rng_key(),
        jnp.arange(start=-noise_max_length + 1, stop=0),
        shape=(k_seq_len,),
        replace=False,
    )
    left_indexes = jnp.sort(left_indexes)
  else:
    left_indexes = -right_indexes[::-1]
    # The leftmost index is required by position_embedding.relative_shift.
    left_indexes = jnp.concatenate([jnp.zeros((1,)), left_indexes])
  zero_index = jnp.zeros((1,))
  indexes = jnp.concatenate([left_indexes, zero_index, right_indexes])

  # We shift the indexes to the range [0, 2*noise_max_length-1], since this
  # will be the range of the sin/cos. In this array, the value at index
  # noise_max_length is the sin/cos encoding at position 0, which is exactly
  # what we want: when doing relative attention, the token should have a fixed
  # encoding of position 0 for its own position.
  indexes += noise_max_length
  indexes = jnp.array(indexes, dtype=jnp.int32)

  positional_encodings = sinusoid_position_encoding(
      sequence_length=noise_max_length,
      hidden_size=hiddens,
      max_timescale=max_time,
      add_negative_side=True,
  )
  positional_encodings = jnp.array(positional_encodings, dtype=jnp.float32)
  positional_encodings = positional_encodings[indexes]
  positional_encodings = jnp.broadcast_to(
      positional_encodings, (batch_size,) + positional_encodings.shape
  )
  relative_keys = hk.Linear(hiddens, with_bias=False)(positional_encodings)
  relative_keys = jnp.reshape(
      relative_keys, positional_encodings.shape[:-1] + (num_heads, num_hiddens)
  )

  # Then compute the relative part.
  relative_bias = hk.get_parameter(
      name='relpos_relativebias',
      shape=[num_heads, num_hiddens],
      init=hk.initializers.RandomNormal(stddev=0.02),
  )
  relative_logits = jnp.einsum(
      'bthd,bThd->bhtT', queries + relative_bias, relative_keys
  )
  # We shift the relative logits instead of the positional encoding matrix as
  # described in Appendix B of the paper (https://arxiv.org/pdf/1901.02860.pdf).
  relative_logits = position_embedding.relative_shift(
      relative_logits,
      attention_length=content_logits.shape[-1],
      causal=False,
  )
  assert content_logits.shape == relative_logits.shape
  return content_logits + relative_logits
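For completeness, a minimal usage sketch. The shapes, the noise_max_length value, and the assumption that the function above and its position_embedding dependency are importable from the repository are all illustrative, not part of the fix:

import haiku as hk
import jax
import jax.numpy as jnp

def attention_fn(queries, keys):
  # Wraps the fixed function above; hk.get_parameter / hk.next_rng_key require
  # running inside a Haiku transform.
  return compute_attention_with_noisy_relative_encodings(
      queries, keys, noise_max_length=32)

forward = hk.transform(attention_fn)

batch, seq_len, heads, head_dim = 2, 8, 4, 16
rng = jax.random.PRNGKey(0)
q = jnp.zeros((batch, seq_len, heads, head_dim))
k = jnp.zeros((batch, seq_len, heads, head_dim))

params = forward.init(rng, q, k)
logits = forward.apply(params, rng, q, k)  # rng is needed for the index sampling.
print(logits.shape)  # (batch, heads, seq_len, seq_len) = (2, 4, 8, 8)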