NathanGavenski/IL-Datasets

[Feature] Support visual environments


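Currently, none of the methods support CNN-based policies. The base class always builds an `MLP` whose input size is the first dimension of the observation space, which only makes sense for vector observations: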

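```python
from abc import ABC
from typing import Any, Dict, Type

import torch
from torch import nn, optim
# Assumed imports: `Env`/`spaces` from Gymnasium and `gym_spaces` from legacy Gym.
# `MLP` is the project's own policy network (import path omitted).
from gymnasium import Env, spaces
from gym import spaces as gym_spaces


class Method(ABC):
    """Base class for all methods."""

    __version__ = "1.0.0"
    __author__ = "Nathan Gavenski"
    __method_name__ = "Abstract Method"

    def __init__(
        self,
        environment: Env,
        environment_parameters: Dict[str, Any],
        discrete_loss: Type[nn.Module] = nn.CrossEntropyLoss,
        continuous_loss: Type[nn.Module] = nn.MSELoss,
        optimizer_fn: Type[optim.Optimizer] = optim.Adam,
    ) -> None:
        """Initialize base class."""
        super().__init__()
        self.environment = environment
        # Accept Discrete action spaces from either Gymnasium or Gym.
        self.discrete = isinstance(environment.action_space, spaces.Discrete)
        self.discrete |= isinstance(environment.action_space, gym_spaces.Discrete)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Assumes a flat vector observation: for an image observation of
        # shape (height, width, channels) this picks up only the height.
        observation_size = environment.observation_space.shape[0]

        if self.discrete:
            action_size = environment.action_space.n
            self.loss_fn = discrete_loss()
        else:
            action_size = environment.action_space.shape[0]
            self.loss_fn = continuous_loss()

        # The policy is always an MLP, so CNN-based policies cannot be used.
        self.policy = MLP(observation_size, action_size)

        # `environment_parameters` is forwarded as keyword arguments to the
        # optimizer constructor (such as `lr`).
        self.optimizer_fn = optimizer_fn(
            self.policy.parameters(),
            **environment_parameters
        )
```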
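To make the limitation concrete: `observation_space.shape[0]` equals the number of input features only for 1-D observations. The small sketch below contrasts that value with the flattened size for a vector state and for an image; the helper names are hypothetical, and the shapes are illustrative (raw Atari RGB frames in Gymnasium have shape `(210, 160, 3)`).

```python
import math


def current_observation_size(shape: tuple[int, ...]) -> int:
    # Mirrors the base class: only the first dimension is used.
    return shape[0]


def flat_observation_size(shape: tuple[int, ...]) -> int:
    # Number of features if the observation were flattened into a vector.
    return math.prod(shape)


vector_shape = (8,)          # an 8-dimensional state vector
image_shape = (210, 160, 3)  # raw RGB frame: (height, width, channels)

print(current_observation_size(vector_shape), flat_observation_size(vector_shape))  # 8 8
print(current_observation_size(image_shape), flat_observation_size(image_shape))    # 210 100800
```

Even flattening to 100800 inputs would discard the spatial structure of the frame, which is why a convolutional policy is the better fit for visual states.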

We should make this more dynamic, but I want to avoid keeping a registry of every environment that records whether its states are visual or vector-based. I think this will be solved once we add support for Atari environments.
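One way to choose the policy without a per-environment registry is to infer it from `observation_space.shape` alone. The sketch below is a minimal illustration under stated assumptions, not necessarily how #17 solved it: `PolicySpec` and `infer_policy_spec` are hypothetical names, and the channels-last test (a last dimension of at most 4 after a larger first dimension) is a heuristic assumption about typical image spaces.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PolicySpec:
    """Hypothetical description of which network to build."""
    kind: str                     # "mlp" or "cnn"
    input_shape: tuple[int, ...]  # (features,) for MLP, (channels, height, width) for CNN


def infer_policy_spec(shape: tuple[int, ...]) -> PolicySpec:
    """Pick a policy family from the observation shape, with no environment registry."""
    if len(shape) == 1:
        # Vector state: the existing MLP path works unchanged.
        return PolicySpec("mlp", (shape[0],))
    if len(shape) == 2:
        # Single-channel image (e.g. a grayscale frame): add a channel axis.
        height, width = shape
        return PolicySpec("cnn", (1, height, width))
    if len(shape) == 3:
        # Assumed heuristic: a small last dimension means channels-last (H, W, C),
        # which is reordered to the (C, H, W) layout PyTorch convolutions expect.
        if shape[-1] <= 4 and shape[0] > 4:
            height, width, channels = shape
            return PolicySpec("cnn", (channels, height, width))
        # Otherwise assume the shape is already channels-first, e.g. stacked frames.
        return PolicySpec("cnn", tuple(shape))
    raise ValueError(f"Unsupported observation shape: {shape}")


print(infer_policy_spec((8,)))           # MLP over 8 features
print(infer_policy_spec((210, 160, 3)))  # CNN over (3, 210, 160)
```

Inside `Method.__init__`, the returned spec would decide whether to build `MLP` or a CNN policy (the CNN itself is not shown). Keying on the shape rather than the environment ID means new environments need no registration.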

Solved by #17