LucasAlegre/sumo-rl

ObservationFunction changes not applied

TGW795 opened this issue · 4 comments

TGW795 commented

Hi.

I've been trying to write code to experiment with MARL using sumo_rl.parallel_env() and my own ObservationFunction, but none of the changes I made are applied. The README says we can use our own ObservationFunction by defining it in observations.py and passing it to the environment constructor, and I followed this flow. Is there anything I am doing wrong? (I'm using the example code for the PettingZoo Multi-Agent API.)

Thank you.

Hi,

Can you share your code?

TGW795 commented

Sure. Here are my code and changes. (I've been checking how several observation values behave, so I haven't written the iteration part yet.)

  • sumo_rl/environment/observations.py
from abc import abstractmethod

import numpy as np
from gymnasium import spaces

from .traffic_signal import TrafficSignal


class ObservationFunction:
    """Abstract base class for observation functions."""

    def __init__(self, ts: TrafficSignal):
        """Initialize observation function."""
        self.ts = ts

    @abstractmethod
    def __call__(self):
        """Subclasses must override this method."""
        phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)]
        observation = np.array(phase_id, dtype=np.float32)
        return observation

    @abstractmethod
    def observation_space(self):
        """Subclasses must override this method."""
        return spaces.Box(
            low=np.zeros(self.ts.num_green_phases, dtype=np.float32),
            high=np.ones(self.ts.num_green_phases, dtype=np.float32),
        )


class DefaultObservationFunction(ObservationFunction):
    """Default observation function for traffic signals."""

    def __init__(self, ts: TrafficSignal):
        """Initialize default observation function."""
        super().__init__(ts)

    def __call__(self) -> np.ndarray:
        """Return the default observation."""
        phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)]  # one-hot encoding
        min_green = [0 if self.ts.time_since_last_phase_change < self.ts.min_green + self.ts.yellow_time else 1]
        density = self.ts.get_lanes_density()
        queue = self.ts.get_lanes_queue()
        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
        return observation

    def observation_space(self) -> spaces.Box:
        """Return the observation space."""
        return spaces.Box(
            low=np.zeros(self.ts.num_green_phases + 1 + 2 * len(self.ts.lanes), dtype=np.float32),
            high=np.ones(self.ts.num_green_phases + 1 + 2 * len(self.ts.lanes), dtype=np.float32),
        )
  • sumo_rl/environment/env.py(#L99)
        observation_class: ObservationFunction = ObservationFunction,

I expected these modifications to yield only phase_id, but I actually got the values defined in DefaultObservationFunction. (I checked this by running the same code as the PettingZoo Multi-Agent API example.)

ObservationFunction is an abstract class; you should not modify it. Instead, create a new class that implements the abstract methods:

class MyObservationFunction(ObservationFunction):

    def __init__(self, ts: TrafficSignal):
        self.ts = ts

    def __call__(self):
        phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)]
        observation = np.array(phase_id, dtype=np.float32)
        return observation

    def observation_space(self):
        return spaces.Box(
            low=np.zeros(self.ts.num_green_phases, dtype=np.float32),
            high=np.ones(self.ts.num_green_phases, dtype=np.float32),
        )

# In your experiment file:
env = sumo_rl.env(..., observation_class=MyObservationFunction)
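
To try the subclass pattern without SUMO installed, here is a minimal standard-library sketch. It is not the sumo-rl code: FakeTrafficSignal and PhaseOnlyObservation are hypothetical names, plain lists stand in for NumPy arrays, and a dict stands in for spaces.Box. The stand-in base class also inherits from ABC, which the base class quoted above does not, so that Python refuses to instantiate it directly. Without ABC, @abstractmethod is not enforced.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class FakeTrafficSignal:
    """Hypothetical stand-in for TrafficSignal with only the fields used below."""

    green_phase: int
    num_green_phases: int


class ObservationFunction(ABC):
    """Stand-in base class; inheriting from ABC makes @abstractmethod enforced."""

    def __init__(self, ts):
        self.ts = ts

    @abstractmethod
    def __call__(self):
        """Return the observation."""

    @abstractmethod
    def observation_space(self):
        """Describe the bounds of the observation."""


class PhaseOnlyObservation(ObservationFunction):
    """Custom observation: one-hot encoding of the current green phase."""

    def __call__(self):
        return [1.0 if self.ts.green_phase == i else 0.0
                for i in range(self.ts.num_green_phases)]

    def observation_space(self):
        # A plain dict stands in for spaces.Box in this sketch.
        n = self.ts.num_green_phases
        return {"low": [0.0] * n, "high": [1.0] * n}


ts = FakeTrafficSignal(green_phase=2, num_green_phases=4)

# The abstract base cannot be instantiated because it inherits from ABC.
try:
    ObservationFunction(ts)
    base_instantiable = True
except TypeError:
    base_instantiable = False

obs_fn = PhaseOnlyObservation(ts)
print(base_instantiable)  # False
print(obs_fn())           # [0.0, 0.0, 1.0, 0.0]
```

The base class stays untouched and the new behaviour lives entirely in the subclass, which is the class you would pass as observation_class.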

TGW795 commented

It worked! Thank you!
This was just an elementary mistake on my part :)