Cannot load DQN model
fqidz opened this issue · 1 comment
First of all, sorry if this doesn't belong here. If it doesn't, I'll post it on the stable-baselines3 GitHub instead.
Hello, I'm a beginner and I can't load my saved DQN model. I trained it using libsumo, and I want to load the saved model so that I can watch its performance in sumo-gui. I'm also not sure whether this is an issue with the environment or with stable-baselines3.
When I try to load the model using `model = DQN.load('./output/model_saved.zip', env=env)`, I get this error:
```
/home/<name>/Documents/sumo-traffic-capstone/utils/stable-baselines3/stable_baselines3/common/save_util.py:167: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.
Exception: code() argument 13 must be str, not int
  warnings.warn(
/home/<name>/Documents/sumo-traffic-capstone/utils/stable-baselines3/stable_baselines3/common/save_util.py:167: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.
Exception: code() argument 13 must be str, not int
  warnings.warn(
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Retrying in 1 seconds
....
```
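The warnings above name only `lr_schedule` and `exploration_schedule`. In a Stable-Baselines3 `.zip`, attributes like these are cloudpickled inside the `data` entry, while the network weights are stored in a separate `policy.pth` entry, so a schedule that fails to deserialize does not by itself mean the weights were skipped. A minimal standard-library sketch for checking what the archive contains (the helper name `list_saved_components` is hypothetical):

```python
import zipfile


def list_saved_components(path: str) -> list[str]:
    """Return the sorted entry names stored in a Stable-Baselines3 .zip archive."""
    with zipfile.ZipFile(path) as archive:
        return sorted(archive.namelist())
```

For example, `list_saved_components('./output/model_saved.zip')` should include `policy.pth` if the weights were saved.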
Here's the `print_system_info` output, in case it's relevant:
```
== CURRENT SYSTEM INFO ==
- OS: Linux-6.7.10-200.fc39.x86_64-x86_64-with-glibc2.38 # 1 SMP PREEMPT_DYNAMIC Mon Mar 18 18:56:52 UTC 2024
- Python: 3.12.2
- Stable-Baselines3: 2.3.0a5
- PyTorch: 2.2.1+cu121
- GPU Enabled: False
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
== SAVED MODEL SYSTEM INFO ==
- OS: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35 # 1 SMP Thu Oct 5 21:02:42 UTC 2023
- Python: 3.10.12
- Stable-Baselines3: 2.3.0a5
- PyTorch: 2.2.1+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
```
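Apart from the OS, the two blocks differ in the Python version (3.12.2 now, 3.10.12 when saved) and the GPU flag; Stable-Baselines3, PyTorch, NumPy, Cloudpickle and Gymnasium are identical. The Python mismatch is the likely culprit: the schedules are cloudpickled functions, and cloudpickle only supports unpickling in the same Python version that pickled them. The error text fits this, since Python 3.11 inserted a new string parameter (`qualname`) into the code-object constructor, so a function pickled under 3.10 passes an int where 3.12 expects a str. A small, hypothetical helper for diffing two `print_system_info` outputs (the excerpts below copy only a few of the fields above):

```python
def parse_system_info(block: str) -> dict[str, str]:
    """Parse '- Key: value' lines from a print_system_info block."""
    info = {}
    for line in block.splitlines():
        line = line.strip()
        if line.startswith("- ") and ":" in line:
            key, value = line[2:].split(":", 1)  # split on the first colon only
            info[key.strip()] = value.strip()
    return info


def diff_system_info(current: str, saved: str) -> dict[str, tuple[str, str]]:
    """Return {field: (current_value, saved_value)} for fields that differ."""
    cur, old = parse_system_info(current), parse_system_info(saved)
    return {
        key: (cur.get(key, ""), old.get(key, ""))
        for key in sorted(cur.keys() | old.keys())
        if cur.get(key) != old.get(key)
    }


CURRENT = """
- Python: 3.12.2
- Stable-Baselines3: 2.3.0a5
- Cloudpickle: 3.0.0
- GPU Enabled: False
"""

SAVED = """
- Python: 3.10.12
- Stable-Baselines3: 2.3.0a5
- Cloudpickle: 3.0.0
- GPU Enabled: True
"""

print(diff_system_info(CURRENT, SAVED))
```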
I tried passing `custom_objects`, as per the example here: `model = DQN.load('./output/model_saved.zip', env=env, custom_objects={'lr_schedule': 0.0, 'exploration_schedule': 0.0})`. That removed the warnings, but it still didn't load the saved model and started learning from the beginning.
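To watch a loaded policy rather than retrain it, the script only needs to step the environment with `model.predict(...)` and never call `model.learn()`. Calling `learn()` on a loaded DQN with the default `reset_num_timesteps=True` restarts the exploration schedule at its initial epsilon, so early actions are mostly random, which can look like training from scratch. A sketch of an evaluation-only loop, assuming the Gymnasium-style `reset()`/`step()` signatures and SB3's `(action, state)` return from `predict`; `StubEnv`, `StubModel` and `run_episode` are hypothetical stand-ins so the snippet runs without SUMO:

```python
class StubEnv:
    """Hypothetical stand-in with Gymnasium-style reset()/step() signatures."""

    def __init__(self, episode_length: int = 5):
        self.episode_length = episode_length
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return [0.0], {}  # (observation, info)

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 0 else 0.0
        terminated = False
        truncated = self.t >= self.episode_length  # end the episode after a fixed length
        return [float(self.t)], reward, terminated, truncated, {}


class StubModel:
    """Hypothetical stand-in for a loaded model; predict() mirrors SB3's (action, state) shape."""

    def predict(self, observation, deterministic=True):
        return 0, None


def run_episode(model, env) -> float:
    """Roll out one episode using the policy only: no learn(), no weight updates."""
    obs, info = env.reset()
    total_reward, done = 0.0, False
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        done = terminated or truncated
    return total_reward


print(run_episode(StubModel(), StubEnv()))
```

With the real objects, the same function would be called as `run_episode(model, env)`, with `model` returned by `DQN.load(...)` and `env` built with `use_gui=True`.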
I'm using this as the env:

```python
from sumo_rl import SumoEnvironment

# my_reward_fn, use_gui and num_seconds are defined elsewhere in the script
env = SumoEnvironment(net_file='./sumo-things/net.net.xml',
                      route_file='./sumo-things/main.rou.xml',
                      out_csv_name='./output/dqn-stats/traffic_sim',
                      reward_fn=my_reward_fn,
                      yellow_time=4,
                      time_to_teleport=2000,
                      use_gui=use_gui,
                      single_agent=True,
                      num_seconds=num_seconds,
                      )
```
Another thing to note: I cloned this repo and ran `pip install -e .` because I wanted to edit part of `env.py` to output the queue length in the CSV file. I'm not sure if it's relevant, but here's the change:
```diff
 def _get_per_agent_info(self):
     stopped = [self.traffic_signals[ts].get_total_queued()
                for ts in self.ts_ids]
     accumulated_waiting_time = [
         sum(self.traffic_signals[ts].get_accumulated_waiting_time_per_lane()) for ts in self.ts_ids
     ]
     average_speed = [self.traffic_signals[ts].get_average_speed()
                      for ts in self.ts_ids]
+    total_queued = [self.traffic_signals[ts].get_total_queued()
+                    for ts in self.ts_ids]
     info = {}
     for i, ts in enumerate(self.ts_ids):
         info[f"{ts}_stopped"] = stopped[i]
         info[f"{ts}_accumulated_waiting_time"] = accumulated_waiting_time[i]
         info[f"{ts}_average_speed"] = average_speed[i]
+        info[f"{ts}_queue_length"] = total_queued[i]
     info["agents_total_stopped"] = sum(stopped)
     info["agents_total_accumulated_waiting_time"] = sum(
         accumulated_waiting_time)
     return info
```
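In this patch, `total_queued` is built from the same `get_total_queued()` call as `stopped`, so the new `_queue_length` column should always equal the existing `_stopped` column. The patch also only adds keys to the info dictionary, so it should not change the observation or action spaces the saved model was trained with. A self-contained sketch that re-runs the patched logic; `FakeSignal` is a hypothetical stub, not sumo-rl's real traffic-signal class, and `"J1"` is a made-up signal ID:

```python
class FakeSignal:
    """Hypothetical stand-in for a sumo-rl traffic signal with fixed readings."""

    def __init__(self, queued: int, waiting: list[float], speed: float):
        self.queued, self.waiting, self.speed = queued, waiting, speed

    def get_total_queued(self) -> int:
        return self.queued

    def get_accumulated_waiting_time_per_lane(self) -> list[float]:
        return self.waiting

    def get_average_speed(self) -> float:
        return self.speed


def per_agent_info(traffic_signals: dict, ts_ids: list[str]) -> dict:
    """Standalone copy of the patched _get_per_agent_info logic."""
    stopped = [traffic_signals[ts].get_total_queued() for ts in ts_ids]
    waiting = [sum(traffic_signals[ts].get_accumulated_waiting_time_per_lane()) for ts in ts_ids]
    speed = [traffic_signals[ts].get_average_speed() for ts in ts_ids]
    total_queued = [traffic_signals[ts].get_total_queued() for ts in ts_ids]  # same call as `stopped`
    info = {}
    for i, ts in enumerate(ts_ids):
        info[f"{ts}_stopped"] = stopped[i]
        info[f"{ts}_accumulated_waiting_time"] = waiting[i]
        info[f"{ts}_average_speed"] = speed[i]
        info[f"{ts}_queue_length"] = total_queued[i]
    info["agents_total_stopped"] = sum(stopped)
    info["agents_total_accumulated_waiting_time"] = sum(waiting)
    return info


signals = {"J1": FakeSignal(queued=3, waiting=[10.0, 5.0], speed=0.4)}
info = per_agent_info(signals, ["J1"])
print(info["J1_stopped"], info["J1_queue_length"])  # both come from get_total_queued()
```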