[Feature Request] Action Masking
Kang-SungKu opened this issue · 16 comments
Motivation
I recently started learning TorchRL, and I am creating a custom environment (using torchrl.envs.EnvBase) based on the documentation (https://pytorch.org/rl/reference/envs.html). For my environment, I would like to apply an action mask, such that the environment does not allow infeasible actions based on the observation (for example, suppose the action is to choose a trump card: the number of A's is limited, so A cannot be chosen once all the A's are drawn). So far, I could not find a way to implement action masking for the environment, but it would be a convenient feature for implementing similar environments.
Solution
It would be convenient if I could include a mask as part of the observation_spec, such that the environment can tell feasible and infeasible actions apart based on the observation (even when a random action is chosen). Currently, my environment cannot pass torchrl.envs.utils.check_env_specs() since infeasible actions are chosen.
If it is not reasonable to implement this feature, any alternative way to implement an action mask is appreciated.
Checklist
I searched with the keyword 'Action Masking', but could not find relevant issues. Sorry if I missed something.
- [x] I have checked that there is no similar issue in the repo (required)
Please! I need this feature badly. Could anyone point me to an RL library that supports action masking?
You could do this in a TensorDictModule.
class ActionMask(TensorDictModule):
    ...
    def forward(self, td):
        # zero out the first action whenever the observation says it is infeasible
        if td['observation'] == 1:
            td['action'][:, 0] = 0
        return td

policy = Sequential(EGreedyWrapper(Sequential(mlp, QValueModule())), ActionMask())
@1030852813, TensorFlow Agents supports action masking. It allows passing a mask to its policies via the parameter named observation_and_action_constraint_splitter, and the policies implement action masking in different ways, for example:
- QPolicy applies the mask so that the q-values of infeasible actions are treated as negative infinity.
- RandomTFPolicy uses the mask to create a MaskedCategorical from which the action is sampled. It seems like torchrl also has MaskedCategorical (https://pytorch.org/rl/reference/generated/torchrl.modules.MaskedCategorical.html#torchrl.modules.MaskedCategorical) based on the same class.
The good thing about their implementation is that the policies are designed to use the same form of observation_and_action_constraint_splitter. I am considering several options based on what @smorad mentioned to see what is reasonable.
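For reference, here is a minimal sketch of how torchrl.modules.MaskedCategorical could sample only feasible actions; this is my reading of the linked documentation rather than tested code, and the shapes are just illustrative:

import torch
from torchrl.modules import MaskedCategorical

logits = torch.randn(6)                                       # unnormalized scores for 6 actions
mask = torch.tensor([True, True, True, False, False, False])  # only the first three actions are feasible
dist = MaskedCategorical(logits=logits, mask=mask)
action = dist.sample()                                        # always one of the unmasked indices 0, 1, 2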
Thanks for reporting this! We should enable this use case, you are totally right.
What is the domain of the actions? Should this work for continuous and discrete or just discrete?
Can you provide a small example of what you'd like to see as a transform for that? In the example you gave the mask is changing continuously, I guess that should be an input to the transform too?
Thank you for your response. I am considering a discrete action space at this point, but it would be useful if it were possible to define masks for a continuous action space as well. I have not come up with the full list of transforms I need, so I am currently following the DQN example (https://pytorch.org/rl/tutorials/coding_dqn.html) without transforms (in other words, an empty Compose() is used) to see what modification is required to make the action mask work. I plan to use the RewardSum transform if I can print the history of episode_reward with the logger and recorder (any guidance on this would be appreciated).
Environment Setting
Here is a simplified version of my environment that needs an action mask (not the actual environment):
- A player has a fixed number of cards (say 6 cards) of two types (say numeric cards 1, 2, 3 and alphabet cards A, B, C), with the same number of numeric and alphabet cards.
- In each step, the player chooses a card of one type (for example, 1st/3rd/5th steps: choose from 1, 2, 3, while 2nd/4th/6th steps: choose from A, B, C) and drops it. The dropped cards cannot be chosen for the rest of the game.
- A game continues until all the cards are dropped, and the reward is determined based on the order of the cards (stochastic reward), given scoring rules.
- In this example, the action mask is an array of 6 binary values, where each value represents the availability of the corresponding card at the current step. Specifically, the mask is [1 1 1 0 0 0] at the 1st step and [0 0 0 1 1 1] at the 2nd step, since the player must drop a numeric card and then an alphabet card. The action mask changes as the player drops cards.
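As a rough sketch of the suggestion in the Solution section, the mask could be exposed through the observation spec so that it travels with every tensordict. The key name action_mask and the specs below are just this example's assumptions, not a torchrl requirement:

import torch
from torchrl.data import CompositeSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec

# Hypothetical observation spec for the 6-card example: a placeholder observation plus
# a boolean "action_mask" entry with one feasibility flag per card.
observation_spec = CompositeSpec(
    observation=UnboundedContinuousTensorSpec(shape=(6,)),
    action_mask=DiscreteTensorSpec(2, shape=(6,), dtype=torch.bool),
)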
Modifications to make the action mask work
Given this environment, here is the list of modifications I needed to make to run the DQN example with action masking. I am new to Torch, so feel free to let me know if there are better ways to implement this.
- It is necessary to redefine rand_action to utilize the action mask. This is required because the collector (may) use the function rollout to collect frames, which generates random actions satisfying the action_spec without considering the action mask. What I did is simply keep sampling random actions until one satisfies the action mask, like below (an alternative without rejection sampling is sketched after the snippet):
def rand_action(self, tensordict: TensorDictBase | None = None):
    shape = tensordict.shape
    infeasible = True
    while infeasible:
        # resample until the one-hot action overlaps with the current mask
        action_onehot = self.action_spec.rand(shape=shape)
        infeasible = not np.dot(action_onehot, self._action_mask) > 0
    tensordict.update({ACTION: action_onehot})
    return tensordict
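For reference, the rejection loop could also be avoided by sampling directly from the mask. This is a hedged, standalone sketch (not torchrl API; the helper name is made up) that draws a one-hot action with torch.multinomial:

import torch

def rand_masked_onehot(mask: torch.Tensor) -> torch.Tensor:
    # mask: boolean tensor of shape [num_actions] with at least one True entry
    idx = torch.multinomial(mask.float(), num_samples=1)  # index of a feasible action
    # one-hot encode the sampled index (cast to the action spec dtype if needed)
    return torch.nn.functional.one_hot(idx.squeeze(-1), num_classes=mask.shape[-1])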
- The make_env function in the DQN example returns the environment as a TransformedEnv, but TransformedEnv seems to use EnvBase.rand_action even when rand_action is redefined in the custom environment. To prevent this, the custom environment is used without conversion to TransformedEnv, but it should have the attribute custom_env.transform to run the trainer (for example, custom_env.transform = Compose() to use an empty transform).
- For the DQN actor, I wrote a TensorDictModule that re-selects the action after changing the Q-values of the infeasible actions (changing the Q-values too much seems to lead to invalid Q-values, so I need to double-check it):
class ActionMaskingModule(TensorDictModuleBase):
    def __init__(self):
        self.in_keys = ['action', 'action_value', 'chosen_action_value', ACTION_MASK]
        self.out_keys = ['action', 'action_value', 'chosen_action_value']
        super().__init__()

    def forward(self, tensordict: TensorDict):
        # Retrieve mask and action_value
        action = tensordict.get(self.in_keys[0], None)
        action_value = tensordict.get(self.in_keys[1], None)
        chosen_action_value = tensordict.get(self.in_keys[2], None)
        action_mask = tensordict.get(self.in_keys[3], None).to(action_value.dtype)
        if action.sum() == 0.0: exit()  # Some exception
        if 2 in action: exit()  # Some exception
        # Update action_value, and then update action & chosen_action_value
        action_value_out = torch.where(condition=action_mask.to(torch.bool),
                                       input=action_value,
                                       other=action_value.min() - 1.0)
        action_out = (action_value_out == action_value_out.max(dim=-1, keepdim=True)[0]).to(torch.long)
        chosen_action_value_out = (action_out * action_value_out).sum(-1, True)
        # Update the tensordict to be returned
        tensordict.update(dict(zip(self.out_keys, (action_out, action_value_out, chosen_action_value_out))))
        return tensordict
The action mask module is attached after the QValueActor as follows when the actor is created:
actor = SafeSequential(QValueActor(module=net,
                                   in_keys=[DESIGN],
                                   spec=env.action_spec,
                                   action_value_key='action_value'),
                       ActionMaskingModule()).to(device)
- Lastly, I needed to define a subclass of EGreedyWrapper, for the same reason I redefined rand_action for my environment, such that random actions are re-sampled until they satisfy the action mask (I will skip the implementation since it is similar to what I did for rand_action; a generic sketch of masked ε-greedy selection is included below for reference).
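For completeness, here is a generic sketch of what such masked ε-greedy selection could look like; it is a standalone helper assuming at most one batch dimension, not the actual EGreedyWrapper subclass:

import torch

def masked_epsilon_greedy(q_values: torch.Tensor, mask: torch.Tensor, eps: float) -> torch.Tensor:
    # q_values and mask have shape [num_actions] or [batch, num_actions]; mask is boolean
    masked_q = q_values.masked_fill(~mask, torch.finfo(q_values.dtype).min)  # hide infeasible actions
    greedy = masked_q.argmax(-1)                                             # greedy feasible action
    random = torch.multinomial(mask.float(), num_samples=1).squeeze(-1)      # random feasible action
    explore = torch.rand(greedy.shape, device=q_values.device) < eps         # ε-greedy coin flip
    return torch.where(explore, random, greedy)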
Question and/or Suggestion
Based on the observations above, I believe it would be helpful to have variants of OneHotDiscreteTensorSpec and/or MultiDiscreteTensorSpec which can incorporate a mask when a random action is sampled. This would address items 1, 2, and 4 at the same time, and eliminate the need to redefine the random sampling in several different places.
Also, in my case, the action mask is similar to a one-hot discrete tensor (it should have at least one 1, except for the terminal state) but may have multiple 1's, so I wonder if there is a spec which satisfies these requirements. I am using MultiDiscreteTensorSpec, which allows sampling an all-zero array, and that requires additional exception handling on my side (especially when a fake_tensordict is generated and fed to an actor).
For a specific type of actor (including the Q-value actor), implementing action masking directly might be too much, since different algorithms may utilize the action mask differently. I believe it is reasonable to add an argument mask (or action_mask_key) and specify where the mask should be applied (probably with NotImplementedError), such that one can easily customize the actors to utilize an action mask.
As I mentioned, I would appreciate it if there are better (or more desirable) ways to implement this.
Could you look at the PR above and let me know if that helps solve your problem?
Here's a notebook with some quick tests
https://gist.github.com/vmoens/95d6427fcb5fa5714291b3dbfa7daa15#file-action_mask-ipynb
Hi there, dropping by since it would be a very useful feature for our RL4CO library :)
So far, we have been dealing with the action_mask internally, and we do the random sampling similarly to how you managed it here, with torch.multinomial. As @Kang-SungKu suggested, it would be useful to add a masking key as you implemented here. A minor detail we would recommend changing is defaulting the mask_key kwarg to action_mask instead of mask, since action_mask seems to be clearer and more widely used, as also done in Gymnasium/Gym in this example.
@vmoens, I re-installed torchrl (masked_action branch) and tensordict based on your notebook, and copy-pasted the script from the section Masked actions in env: ActionMask transform without modification.
I confirmed that rand_step(td) draws infeasible actions without the ActionMask() transform applied (I assume this is the intended behavior), but I still get the following error when I run rollout with the environment transformed with ActionMask():
Traceback (most recent call last):
File "C:\{workspace_path}\test.py", line 45, in <module>
r = env.rollout(10)
File "C:\{conda_env_path}\lib\site-packages\torchrl\envs\common.py", line 1222, in rollout
tensordict = self.reset()
File "C:\{conda_env_path}\lib\site-packages\torchrl\envs\common.py", line 944, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "C:\{conda_env_path}\lib\site-packages\torchrl\envs\transforms\transforms.py", line 696, in _reset
out_tensordict = self.transform._call(out_tensordict)
File "C:\{conda_env_path}\lib\site-packages\torchrl\envs\transforms\transforms.py", line 4461, in _call
mask = tensordict.get(self.in_keys[1])
AttributeError: 'DiscreteTensorSpec' object has no attribute 'get'
I get the same attribute error when I change the action_spec from DiscreteTensorSpec to OneHotDiscreteTensorSpec as well. Can you let me know what the issue might be? I think something might be slightly different from your settings, or I missed something simple. I appreciate your time!
A minor detail we would recommend changing is defaulting the mask_key kwarg to action_mask instead of mask, since action_mask seems to be clearer and more widely used, as also done in Gymnasium/Gym in this example.
I can see the point. A self-explanatory name and a clear purpose are great.
@vmoens, is there anything I can check to address the AttributeError: 'DiscreteTensorSpec' object has no attribute 'get' (maybe the versions of specific dependencies)? Sorry if these are simple things that I am missing. I really appreciate your time!
@vmoens, thank you for your help, which made the implementation cleaner. I made several modifications to make the action_mask work with the DQN tutorial. I am still validating whether there are any unhandled exceptions, and I would appreciate suggestions for more efficient implementations.
Modified Locations (Note: the masked_actions branch of torchrl should be used to enable the experimental masking features.)
- The function reset() of the ActionMask transform should be fixed to return tensordict rather than action_spec to avoid AttributeError: 'DiscreteTensorSpec' object has no attribute 'get'. This fix allows the transformed environment to run rollout(policy=None) without an error. The corresponding location is as follows:
rl/torchrl/envs/transforms/transforms.py
Line 4477 in 7d291a7
- Experimental masking features are not enabled for EGreedyWrapper. To make it use a mask, it is necessary to assign the mask to the action spec, like spec.mask = tensordict['action_mask'], as shown below:
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
    tensordict = self.td_module.forward(tensordict)
    ...
    if spec is not None:
        if isinstance(spec, CompositeSpec):
            spec = spec[self.action_key]
        if 'action_mask' in tensordict.keys():
            spec.mask = tensordict['action_mask']
        out = cond * spec.rand().to(out.device) + (1 - cond) * out
    else:
        raise RuntimeError(
            "spec must be provided by the policy or directly to the exploration wrapper."
        )
    action_tensordict.set(action_key, out)
    return tensordict
The corresponding location is as follows:
- To make QValueActor (or other actors) use the mask, it is necessary to add an additional TensorDictModule that applies the action mask. I created a masked variant of QValueActor where an additional TensorDictModule is applied between the q-net module (such as DuelingCnnDQNet) and QValueModule. The following code shows an intermediate module masking added before QValueModule to apply the mask before evaluating the argmax. Other than that, this class is the same as the original QValueActor.
class MaskedQValueActor(SafeSequential):
    def __init__(
        self,
        module,
        *,
        in_keys=None,
        spec=None,
        safe=False,
        action_space: Optional[str] = None,
        action_value_key=None,
        action_mask_key=None,
    ):
        ...
        ## Create q_net module
        action_space, spec = _process_action_space_spec(action_space, spec)
        ...
        spec[action_value_key] = None
        spec["chosen_action_value"] = None
        ...
        ##### Create masking module start ######################
        if action_mask_key is None:
            action_mask_key = "action_mask"
        in_keys_masking = [action_value_key, action_mask_key]
        masking = TensorDictModule(
            lambda action_value, action_mask: torch.where(action_mask, action_value, torch.finfo(action_value.dtype).min),
            # lambda action_value, action_mask: torch.where(action_mask, action_value, action_value.min() - 1.0),
            in_keys=in_keys_masking,
            out_keys=[action_value_key],
        )
        ##### Create masking module end ########################
        ## Create q_value module
        qvalue = QValueModule(
            action_value_key=action_value_key,
            out_keys=out_keys,
            spec=spec,
            safe=safe,
            action_space=action_space,
        )
        super().__init__(module, masking, qvalue)
- I needed to modify DQNLoss since it seems to expect a one-hot encoded action spec. I disabled action = action.unsqueeze(-1) at the following location to avoid adding a redundant axis to the action:
Line 293 in 7d291a7
Thanks for this.
I'm on PTO these days so that's why I'm not actively participating in this thread (or others FWIW).
I will catch up as soon as I'm back! Sorry about that.
No problem. I will let you know if I have any updates about this.
The PR should be fixed now, the notebook works on my side. If you're happy with it we can merge it.
I confirmed that the notebook works on my side as well.
If possible, I suggest that EGreedyWrapper accept an optional argument like action_mask_key, such that spec.mask is assigned to apply the action mask when the key is passed to EGreedyWrapper (see item 2 in my previous report).
I still need to make the same modification to DQNLoss to make the DQN tutorial work (see item 4 in my previous report) when the action spec is DiscreteTensorSpec or BinaryDiscreteTensorSpec. I am not sure whether modifying it would affect other components, so it might require additional investigation.
I see that there is an ongoing discussion on check_env_specs in the PR. It would be convenient if the env could pass check_env_specs with the action mask incorporated.
I think the PR can be merged once EGreedyWrapper and check_env_specs work with an action mask.
As discussed with @matteobettini, we think EGreedy and DQN compatibility is a separate issue that deserves to be addressed in independent PRs. Happy to tackle these asap.