[Feature] Support visual environments
NathanGavenski commented
Currently, none of the methods support CNN-based policies: the shared base class below always builds an `MLP` sized from the first dimension of the observation space.
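```python
from abc import ABC
from typing import Any, Dict, Type

import torch
from torch import nn, optim
from gymnasium import Env, spaces  # assumed: Gymnasium spaces
from gym import spaces as gym_spaces  # assumed: legacy Gym spaces

# MLP is the project's own policy network; its import is omitted here.


class Method(ABC):
    """Base class for all methods."""

    __version__ = "1.0.0"
    __author__ = "Nathan Gavenski"
    __method_name__ = "Abstract Method"

    def __init__(
        self,
        environment: Env,
        environment_parameters: Dict[str, Any],
        discrete_loss: Type[nn.Module] = nn.CrossEntropyLoss,
        continuous_loss: Type[nn.Module] = nn.MSELoss,
        optimizer_fn: Type[optim.Optimizer] = optim.Adam,
    ) -> None:
        """Initialize base class."""
        super().__init__()
        self.environment = environment
        # Accept Discrete action spaces from both Gymnasium and legacy Gym.
        self.discrete = isinstance(environment.action_space, spaces.Discrete)
        self.discrete |= isinstance(environment.action_space, gym_spaces.Discrete)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Assumes a flat vector state: for image observations, shape[0]
        # is only the first dimension (e.g. the frame height).
        observation_size = environment.observation_space.shape[0]
        if self.discrete:
            action_size = environment.action_space.n
            self.loss_fn = discrete_loss()
        else:
            action_size = environment.action_space.shape[0]
            self.loss_fn = continuous_loss()

        # Always builds an MLP, so a CNN-based policy is never used.
        self.policy = MLP(observation_size, action_size)
        self.optimizer_fn = optimizer_fn(
            self.policy.parameters(),
            **environment_parameters
        )
```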
We should make this more dynamic, but I want to avoid maintaining a registry of all environments that records whether their states are visual or vector-based. I think this will be solved once we add support for Atari environments.
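One way to avoid a per-environment registry, sketched here as an assumption and not necessarily what #17 implemented, is to infer the policy type from the rank of the observation shape alone: 1-D states go to an MLP, 2-D or 3-D states are treated as images for a CNN. It also shows why `environment.observation_space.shape[0]` breaks on visual inputs: for a raw Atari frame of shape `(210, 160, 3)` it returns 210, the frame height, rather than a usable input size. The names `PolicySpec` and `infer_policy_spec`, and the channels-last default, are hypothetical.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class PolicySpec:
    """Hypothetical description of which policy network to build."""

    kind: str  # "mlp" for vector states, "cnn" for image states
    input_size: int  # flat feature count (MLP) or channel count (CNN)
    image_hw: Tuple[int, int] = (0, 0)  # (height, width) for CNNs; unused for MLPs


def infer_policy_spec(obs_shape: Tuple[int, ...], channels_last: bool = True) -> PolicySpec:
    """Choose a policy type from the observation shape, with no environment registry.

    Assumption: 1-D shapes are vector states, 2-D shapes are grayscale images,
    and 3-D shapes are images laid out as (H, W, C) when channels_last is True,
    otherwise (C, H, W).
    """
    if len(obs_shape) == 1:
        return PolicySpec(kind="mlp", input_size=obs_shape[0])
    if len(obs_shape) == 2:
        height, width = obs_shape
        return PolicySpec(kind="cnn", input_size=1, image_hw=(height, width))
    if len(obs_shape) == 3:
        if channels_last:
            height, width, channels = obs_shape
        else:
            channels, height, width = obs_shape
        return PolicySpec(kind="cnn", input_size=channels, image_hw=(height, width))
    raise ValueError(f"Unsupported observation shape: {obs_shape}")
```

For example, `infer_policy_spec((4,))` selects an MLP with 4 inputs, while `infer_policy_spec((210, 160, 3))` selects a CNN with 3 input channels on a 210x160 frame. The base class could then call it with `environment.observation_space.shape` in place of `shape[0]` and build the matching network.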
NathanGavenski commented
Solved by #17