pytorch/tensordict

[Feature Request] Using `pad_sequence(... return_mask=True)` leads to confusing debug sessions

kurt-stolle opened this issue · 7 comments

Motivation

Working on computer vision applications, I often encounter data structures that contain a key called "masks", most commonly holding something like semantic segmentation outputs.

The function pad_sequence has an argument return_mask that adds a key also called "masks". This can cause confusing behavior or bugs, especially since the overwrite may go unnoticed and be hard to trace back.

Currently, I cannot control which key the padding mask is written to. More generally, there is no clear way to distinguish the keys that TensorDict uses internally from my own user keys.
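The failure mode can be reproduced with a toy model. This is not tensordict code: it assumes plain dicts of Python lists stand in for TensorDicts, and toy_pad_sequence is a hypothetical helper that, like the behavior described above, writes the padding mask under a hard-coded "masks" key.

```python
from typing import Dict, List

Record = Dict[str, List[int]]  # toy stand-in for one TensorDict: key -> sequence


def toy_pad_sequence(records: List[Record], return_mask: bool = False) -> Dict[str, list]:
    """Hypothetical helper: right-pad every sequence with 0 and stack the records."""
    # Toy assumption: all keys in a record share one length (the sequence length).
    lengths = [len(next(iter(r.values()))) for r in records]
    max_len = max(lengths)
    out: Dict[str, list] = {
        key: [r[key] + [0] * (max_len - n) for r, n in zip(records, lengths)]
        for key in records[0]
    }
    if return_mask:
        # Hard-coded key: silently replaces any user entry called "masks".
        out["masks"] = [[i < n for i in range(max_len)] for n in lengths]
    return out


batch = [
    {"boxes": [1, 2], "masks": [7, 7]},  # "masks" holds user segmentation data
    {"boxes": [3], "masks": [9]},
]
padded = toy_pad_sequence(batch, return_mask=True)
print(padded["masks"])  # [[True, True], [True, False]]: the user's data is gone
```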

Solution

Make the interface with return_mask=True return the valid-entry mask as a secondary output, e.g.

def pad_sequence(..., return_mask=False) -> TensorDict: ...  # padded tensordict
def pad_sequence(..., return_mask=True) -> tuple[TensorDict, TensorDict]: ... # padded tensordict, valid masks
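A minimal sketch of the proposed two-output interface, again on toy data (plain dicts of lists standing in for TensorDicts, and a single mask list per record rather than a TensorDict of masks). The helper name is hypothetical; typing.overload with Literal expresses that the return type depends on return_mask, which is what the two signatures above describe.

```python
from typing import Dict, List, Literal, Tuple, overload

Record = Dict[str, List[int]]
Padded = Dict[str, List[List[int]]]
Mask = List[List[bool]]


@overload
def toy_pad_sequence(records: List[Record], return_mask: Literal[False] = ...) -> Padded: ...
@overload
def toy_pad_sequence(records: List[Record], return_mask: Literal[True]) -> Tuple[Padded, Mask]: ...


def toy_pad_sequence(records, return_mask=False):
    """Hypothetical helper: right-pad with 0; optionally return the mask separately."""
    lengths = [len(next(iter(r.values()))) for r in records]
    max_len = max(lengths)
    padded = {
        key: [r[key] + [0] * (max_len - n) for r, n in zip(records, lengths)]
        for key in records[0]
    }
    if not return_mask:
        return padded
    # The mask is a second return value, so no key in `padded` is overwritten.
    mask = [[i < n for i in range(max_len)] for n in lengths]
    return padded, mask


batch = [{"boxes": [1, 2], "masks": [7, 7]}, {"boxes": [3], "masks": [9]}]
padded, valid = toy_pad_sequence(batch, return_mask=True)
print(padded["masks"])  # [[7, 7], [9, 0]]
print(valid)            # [[True, True], [True, False]]
```

The cost of this design is that callers must unpack a tuple only when return_mask=True, which is the trade-off debated later in the thread.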

Alternatives

  1. Expose all internally used keys as constants prefixed with an underscore, e.g. export a variable KEY_MASK: Final = "_mask", to clearly distinguish internal keys from user keys.
  2. Let the caller pass the key (default "masks") as an argument, either as a str value for return_mask or as a separate parameter.
  3. Copy-paste the method to my own codebase and change the key.
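Alternative 1 could look roughly like this. KEY_MASK and add_padding_mask are hypothetical names, not part of tensordict, and raising on a clash instead of overwriting is an extra design choice added here for illustration.

```python
from typing import Dict, Final, List

# Hypothetical public constant for the internally used key (alternative 1).
KEY_MASK: Final = "_mask"


def add_padding_mask(padded: Dict[str, list], lengths: List[int]) -> Dict[str, list]:
    """Return a copy of `padded` with a validity mask stored under KEY_MASK."""
    if KEY_MASK in padded:
        # Refuse to clobber user data instead of overwriting it silently.
        raise KeyError(f"{KEY_MASK!r} is reserved for the padding mask")
    max_len = max(lengths)
    out = dict(padded)
    out[KEY_MASK] = [[i < n for i in range(max_len)] for n in lengths]
    return out


padded = {"masks": [[7, 7], [9, 0]]}  # user segmentation data, already padded
result = add_padding_mask(padded, lengths=[2, 1])
print(sorted(result))  # ['_mask', 'masks']
```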

Additional context

I understand that this is a minor/niche request, but since PyTorch is central to many computer vision applications, I hope TensorDict adoption could be helped by avoiding the keys users are most likely to choose themselves.

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

Since it's a separate key, you could do

from tensordict import pad_sequence
padded = pad_sequence(tds, return_mask=True)  # tds: a list of TensorDicts
masks = padded.pop("masks")

That being said, I see your point. The reason we pack everything together is that the return type stays the same regardless of the arguments. One thing I personally don't enjoy is when changing a keyword argument forces me to adapt to a different return type. For instance, I find max confusing in PyTorch (tensor.max() returns a single value, while tensor.max(1) returns a named tuple of values and indices). With tensordict we can avoid that in almost all cases, since we can pack the results together.

@dtsaras what do you think about this?

That makes sense, though I raised this initially because a tensordict may already have a key called "masks", which is quite common in vision applications. That entry would be overwritten and unrecoverable (?) after padding with pad_sequence.

I personally use it as is, and I like that all the associated data are packed in the same TD. I am not a big fan of splitting the TD into the padded TD and the associated masks. We can definitely adjust the function if the community needs to change the masks key, though. For example, the user could provide a padding_masks_key argument that defaults to "masks" and can be set to something like "padding_masks" when "masks" is likely to overwrite other data.
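The keyword-argument variant keeps everything in one container while letting the caller choose the key. In this sketch padding_masks_key is the name proposed in the comment; the helper and the toy data (plain dicts of lists standing in for TensorDicts) are hypothetical.

```python
from typing import Dict, List

Record = Dict[str, List[int]]


def toy_pad_sequence(
    records: List[Record],
    return_mask: bool = False,
    padding_masks_key: str = "masks",  # hypothetical parameter, default "masks"
) -> Dict[str, list]:
    """Hypothetical helper: right-pad with 0; store the mask in the same dict."""
    lengths = [len(next(iter(r.values()))) for r in records]
    max_len = max(lengths)
    out: Dict[str, list] = {
        key: [r[key] + [0] * (max_len - n) for r, n in zip(records, lengths)]
        for key in records[0]
    }
    if return_mask:
        # The return type never changes; only the key name is configurable.
        out[padding_masks_key] = [[i < n for i in range(max_len)] for n in lengths]
    return out


batch = [{"boxes": [1, 2], "masks": [7, 7]}, {"boxes": [3], "masks": [9]}]
padded = toy_pad_sequence(batch, return_mask=True, padding_masks_key="padding_masks")
print(padded["masks"])          # [[7, 7], [9, 0]]
print(padded["padding_masks"])  # [[True, True], [True, False]]
```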

where a tensordict could already have a key called "masks", seeing as this is quite a common key (in vision applications)

What about

return_mask=True  # sets "masks" key
return_mask="bibbidi-bobbidi-boo"  # sets "bibbidi-bobbidi-boo" key
return_mask=False  # sets nothing

I guess this works but it certainly isn't pretty

haha yeah
It's the kind of "python isn't strongly typed so let me do whatever I want" solution :)

I guess most people won't really care about prettiness, and most will just use the boolean.
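Whatever the aesthetics, the bool-or-string argument discussed above reduces to one small dispatch step. resolve_mask_key is a hypothetical helper, not tensordict's implementation; it returns the key to write, or None when no mask is requested.

```python
from typing import Optional, Union


def resolve_mask_key(return_mask: Union[bool, str], default: str = "masks") -> Optional[str]:
    """Hypothetical helper: map the bool-or-str argument to a key, or None."""
    # Identity checks, so only real booleans hit these branches (1 and 0 do not).
    if return_mask is True:
        return default
    if return_mask is False:
        return None
    if isinstance(return_mask, str):
        return return_mask
    raise TypeError(f"return_mask must be a bool or a str, got {type(return_mask).__name__}")


print(resolve_mask_key(True))                   # masks
print(resolve_mask_key("bibbidi-bobbidi-boo"))  # bibbidi-bobbidi-boo
print(resolve_mask_key(False))                  # None
```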

I made the changes and updated the tests. Do you think it would be better for return_mask to accept a NestedKey instead of a string? @kurt-stolle #739
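A NestedKey would let the padding mask live in a sub-structure, away from top-level user keys such as "masks". This sketch assumes plain nested dicts in place of a TensorDict; set_nested is a hypothetical helper, and the NestedKey alias is a simplified version of tensordict's type (a string or a tuple of strings).

```python
from typing import Any, Dict, Tuple, Union

# Simplified stand-in for tensordict's NestedKey: a str or a tuple of str.
NestedKey = Union[str, Tuple[str, ...]]


def set_nested(data: Dict[str, Any], key: NestedKey, value: Any) -> None:
    """Write `value` at `key`, creating intermediate dicts as needed."""
    path = (key,) if isinstance(key, str) else tuple(key)
    if not path or not all(isinstance(part, str) for part in path):
        raise TypeError(f"invalid nested key: {key!r}")
    node = data
    for part in path[:-1]:
        node = node.setdefault(part, {})
    node[path[-1]] = value


out = {"masks": [[7, 7], [9, 0]]}  # user segmentation data at the top level
set_nested(out, ("padding", "masks"), [[True, True], [True, False]])
print(out)
# {'masks': [[7, 7], [9, 0]], 'padding': {'masks': [[True, True], [True, False]]}}
```

A plain string key still works (set_nested(out, "masks", ...)), so accepting a NestedKey is a strict superset of accepting a string.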
We had also talked before about adding an option for selecting which keys to pad. Would you like to add this in this PR, or should I make a separate one? @vmoens