[Feature Request] Passing custom activation function in policy_kwargs
paolodelia99 opened this issue · 2 comments
paolodelia99 commented
🚀 Feature
Possibility to pass a flax activation function (from the flax.linen.activation module) through the policy_kwargs argument when creating an sbx model.
Motivation
In the current implementation of sbx, users cannot pass a custom activation function when creating a model. This limits flexibility and may not suit every use case.
Pitch
Example:
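import gymnasium as gym
import flax.linen as nn
from sbx import TD3

env = gym.make("Pendulum-v1")  # any continuous-action env; Pendulum-v1 is only an example
my_custom_activation_fn = nn.tanh  # e.g. any activation from flax.linen
policy_kwargs = dict(activation_fn=my_custom_activation_fn, net_arch=dict(pi=[64, 64], qf=[64, 64]))
model = TD3("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)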
Idea on how to implement it
Add an activation_fn attribute to the underlying classes that compose the policy (like Critic and Actor in td3/policy.py).
araffin commented
Hello,
Sounds reasonable. Would you contribute such a feature?
paolodelia99 commented
Sure.