
RF masked computation / masking (like masked_select but without the packing)


I currently need the masking functionality which we have in the TF MaskedComputationLayer (e.g. when you just pass in a CopyLayer): given a tensor [B,T,D] and a mask [B,T], I want to get out a new tensor [B,T',D], where T' is the reduced axis.

I'm a bit surprised that we don't already have such a function.

When looking for it, I first found masked_select, and wondered for a while how this is different. Then I realized that masked_select also packs the tensor. That feels a bit inconsistent, e.g. compared to gather: why should masked_select do both the masked selection and then also the packing? Maybe the packing could itself be an atomic operation. (Currently, pack_padded actually uses sequence_mask + masked_select, because masked_select does the packing.) The naming comes from PyTorch, where masked_select also does this. It's also the same as TF boolean_mask.
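
To make the difference concrete, here is a small hypothetical illustration in plain TF (shapes only; not from the actual code):

import tensorflow as tf

source = tf.random.normal([2, 4, 5])  # [B,T,D]
mask = tf.constant([[True, True, True, False], [False, True, False, False]])  # [B,T]

packed = tf.boolean_mask(source, mask)  # what masked_select / boolean_mask gives: [sum(mask), D] = [4, 5]
# What is asked for here instead is [B,T',D] = [2, 3, 5]: T' = 3 is the maximum per-sequence
# number of kept frames, and the second sequence (only 1 kept frame) gets padded up to 3.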

Here is some TF code which does this (extracted from our MaskedComputationLayer; where_bc, get_shape and nd_indices are helpers from returnn.tf.util.basic):

# mask is the layer with the mask, e.g. [B,T]
assert (
    mask.output.have_time_axis() and mask.output.shape == (None,) and mask.output.dtype == "bool"
), "%s: invalid mask %s (outside rec loop)" % (self, mask)
assert in_spatial_dim and out_spatial_dim
mask_data = mask.output.copy_as_time_major()
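# combine the mask with the sequence mask, so frames beyond the seq lengths never count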
mask_t = where_bc(mask_data.placeholder, mask_data.get_sequence_mask(), tf.convert_to_tensor(False))
idxs = tf.cumsum(tf.cast(mask_t, tf.int32), axis=0)  # [T,B] -> idx in T' + 1
new_size = idxs[-1]  # [B]
out_spatial_dim = out_spatial_dim.get_for_batch_ctx(self.output.batch, self.output.control_flow_ctx)
if out_spatial_dim.dyn_size is None:
    out_spatial_dim.dyn_size = new_size
new_time = tf.reduce_max(new_size)  # T'
idxs = where_bc(mask_t, idxs - 1, new_time)

And then:

tmp_shape = get_shape(source_data.placeholder)
tmp_shape[0] = new_time + 1  # one more for the padded data
res = tf.scatter_nd(nd_indices(idxs, batch_axis=1), source_data.placeholder, shape=tmp_shape)
res_data = source_data.copy_template().copy_template_replace_dim_tag(
    axis=0, new_dim_tag=out_spatial_dim
)
res_data.placeholder = res[:new_time]

For a pure RF implementation, we could follow the logic of the TF code, i.e. use cumsum on the mask, then where to mask out the masked frames, then scatter to copy everything into the resulting tensor. The where beforehand would send all masked frames (also padded frames) to some dummy last frame in the resulting tensor (tmp_shape[0] = new_time + 1 # one more for the padded data in the TF code), and that last frame would be cut off (res[:new_time] in TF, i.e. slice in RF).
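
To make the logic concrete, here is a self-contained sketch in plain TF (batch-major, plain tensors instead of RETURNN Data/Dim objects; the function name masked_shorten is just made up for illustration):

import tensorflow as tf

def masked_shorten(source, mask, seq_lens):
    """source: [B,T,D], mask: [B,T] bool, seq_lens: [B] int32 -> (res [B,T',D], new_size [B])."""
    b, t, d = tf.shape(source)[0], tf.shape(source)[1], tf.shape(source)[2]
    mask = tf.logical_and(mask, tf.sequence_mask(seq_lens, maxlen=t))  # padded frames never count
    idxs = tf.cumsum(tf.cast(mask, tf.int32), axis=1)  # [B,T] -> target idx in T' + 1
    new_size = idxs[:, -1]  # [B], number of kept frames per sequence
    new_time = tf.reduce_max(new_size)  # T'
    # all masked-out frames point to one dummy frame at position T'
    idxs = tf.where(mask, idxs - 1, tf.fill(tf.shape(mask), new_time))
    batch_idxs = tf.tile(tf.range(b)[:, None], [1, t])  # [B,T]
    res = tf.scatter_nd(tf.stack([batch_idxs, idxs], axis=-1), source, tf.stack([b, new_time + 1, d]))
    return res[:, :new_time], new_size  # cut off the dummy frame

new_size would then give the per-sequence lengths of the new T' dim (what the TF code above assigns to out_spatial_dim.dyn_size).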

Alternatively, we could also use masked_select. We would need to make sure that the resulting packed dim actually also contains padded frames and can be unpacked just by reshaping, i.e. split_dims. So we need to calculate the new max size T', and then add dummy padded frames such that every sequence contributes exactly T' frames. To guarantee that this is always possible, the dummy frames should be concatenated to the mask, and then correspondingly also to the source itself. (E.g. imagine the mask (B,T) [[1 1 1 0] [0 0 0 1]]: it needs to be extended by two dummy frames for the second sequence, giving [[1 1 1 0 0 0] [0 0 0 1 1 1]].) However, that causes a copy of the source, which might be large, which we would want to avoid. And I'm not sure that there is a simple other way.
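
For illustration, a hypothetical plain-TF sketch of this variant, using the example mask from above (not actual RETURNN code; boolean_mask plays the role of masked_select):

import tensorflow as tf

source = tf.random.normal([2, 4, 5])  # [B,T,D]
mask = tf.constant([[True, True, True, False], [False, False, False, True]])  # [B,T], the example above
new_size = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)  # [B] = [3, 1]
new_time = tf.reduce_max(new_size)  # T' = 3
n_pad = tf.reduce_max(new_time - new_size)  # 2 extra frames needed in the worst case
# enable exactly (T' - new_size) dummy frames per sequence, so every row keeps exactly T' frames
pad_mask = tf.sequence_mask(new_time - new_size, maxlen=n_pad)  # [[0 0], [1 1]]
mask_ext = tf.concat([mask, pad_mask], axis=1)  # [[1 1 1 0 0 0], [0 0 0 1 1 1]]
source_ext = tf.concat([source, tf.zeros([2, 2, 5])], axis=1)  # n_pad=2 here; this copies the source
packed = tf.boolean_mask(source_ext, mask_ext)  # [B*T', D] = [6, 5]
res = tf.reshape(packed, [2, 3, 5])  # unpack just by reshaping (split_dims) -> [B,T',D]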

Another question: Should this be a new function? Or part of the existing masked_select?

The existing masked_select has a dims argument, which specifies the dims that are going to be packed into a single resulting dim.

So, if dims is also just a single dim, e.g. the T from source [B,T,D] and mask [B,T], that would result in a new dim T'.

Is this always non-ambiguous?

How does the current implementation actually handle this?

Should this logic be implemented in pure RF or inside the backend? I think pure RF makes more sense.

Is this just always the case when len(dims) == 1? I.e., keep the current packing logic for len(dims) >= 2, and for len(dims) == 1, do the logic as described here.