This repository contains the final project for the Reinforcement Learning course at Skoltech University. It tackles a non-RL problem (object localization) using Deep Reinforcement Learning, and is mainly based on the paper *Active Object Localization with Deep Reinforcement Learning*.
The repository has three branches:

- `main`: implementation following the original article (PyTorch + DQN), based on Rayan Samy's code referenced below.
- `resnet50`: same as `main`, but the feature extractor is a pre-trained ResNet-50 (instead of the VGG-16 used in `main`).
- `sb3_gym`: an experimental implementation based on the PPO algorithm from the Stable Baselines 3 library (the feature extractor is still a pre-trained CNN).
The repository is organized as follows:

- `Training.ipynb` is used to reproduce the training process of the model.
- `Testing.ipynb` is used to reproduce the testing process of the model and to visualize some examples of localization.
- `Plotting.ipynb` is used to plot all the graphs and charts shown above, using the `media` folder.
- `media` is a folder for saving all examples of localization and graphs.
- `models` is a directory needed to keep the saved model weights after training; an example of such a folder is included.
- `utils` is a directory containing the following files:
  - `agent.py`: a wrapper for the per-class agent that contains all the RL components (ϵ-greedy policy, reward, etc.); a sketch of the ϵ-greedy selection follows this list.
  - `models.py`: a wrapper for the two main modules of the network: the feature extractor and the DQN.
  - `dataset.py`: a separate file for reading the dataset (train and val splits).
  - `tools.py`: a collection of useful functions, such as computing the metrics or extracting a specified class from the dataset.
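For illustration, the ϵ-greedy action selection inside such an agent might look like the following minimal sketch (the names are ours, not the exact API of `agent.py`; the original paper uses 9 actions: 8 box transformations plus a terminal trigger):

```python
import random

import torch


def select_action(q_network, state, epsilon, n_actions=9):
    """Minimal epsilon-greedy sketch (illustrative, not the repo's exact code).

    With probability epsilon, explore a uniformly random action;
    otherwise exploit the action with the highest predicted Q-value.
    """
    if random.random() < epsilon:
        return random.randrange(n_actions)        # explore
    with torch.no_grad():
        q_values = q_network(state)               # shape: (1, n_actions)
    return int(q_values.argmax(dim=1).item())     # exploit
```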
The following figure shows a high-level diagram of the DQN model architecture, as presented by the authors of the original work:
Following the paper, we used a VGG-16 model pre-trained on ImageNet as our CNN feature extractor.
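As a rough sketch of this two-stage design (the layer sizes here are illustrative assumptions rather than the exact values in `models.py`: 25088 = 512 × 7 × 7 flattened VGG-16 conv features for a 224 × 224 crop, and 90 = 9 actions × a 10-step action history as in the paper):

```python
import torch.nn as nn
from torchvision import models


class FeatureExtractor(nn.Module):
    """Frozen VGG-16 backbone pre-trained on ImageNet (requires torchvision >= 0.13)."""

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.features = vgg.features
        self.flatten = nn.Flatten()
        for p in self.parameters():
            p.requires_grad = False   # the backbone is not fine-tuned

    def forward(self, x):
        return self.flatten(self.features(x))


class DQN(nn.Module):
    """Q-network head mapping (image features + action history) to Q-values."""

    def __init__(self, feat_dim=25088, history_dim=90, n_actions=9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim + history_dim, 1024), nn.ReLU(),
            nn.Linear(1024, 1024), nn.ReLU(),
            nn.Linear(1024, n_actions),
        )

    def forward(self, state):
        return self.net(state)
```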
We used the PASCAL VOC 2007 dataset, a well-known benchmark for object recognition. It contains images of 20 different classes, ranging from people and animals to vehicles and indoor objects. For the sake of this academic project, we trained on a reduced number of classes from the training set and used the validation set for testing.
P.S. The current version of the code only supports an offline copy of the dataset, because the official dataset website was down, as described at the start of `Training.ipynb`.
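For reference, loading an offline copy of VOC 2007 with `torchvision` and extracting the samples of a single class could look like this sketch (the paths and helper name are hypothetical; the repo's actual loader lives in `utils/dataset.py`):

```python
from torchvision.datasets import VOCDetection

# Assumes VOC 2007 has already been downloaded and extracted under ./data,
# matching the offline-only setup described above.
train_set = VOCDetection(root="./data", year="2007", image_set="train", download=False)
val_set = VOCDetection(root="./data", year="2007", image_set="val", download=False)


def indices_of_class(dataset, class_name):
    """Return indices of samples whose annotation contains the given VOC class."""
    indices = []
    for i in range(len(dataset)):        # note: this loads each image, so it is slow
        _, target = dataset[i]
        objects = target["annotation"]["object"]
        if isinstance(objects, dict):    # a lone object may not be wrapped in a list
            objects = [objects]
        if any(obj["name"] == class_name for obj in objects):
            indices.append(i)
    return indices


aeroplane_indices = indices_of_class(train_set, "aeroplane")
```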
Following the original paper, we used AP (Average Precision) as our accuracy metric, alongside Recall.
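Both metrics rest on IoU-based matching between predicted and ground-truth boxes; a minimal sketch (illustrative names, not the exact code in `utils/tools.py`) is:

```python
def iou(box_a, box_b):
    """Intersection over Union for boxes given as (x1, y1, x2, y2)."""
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def precision_recall_at_iou(predictions, ground_truths, threshold=0.5):
    """Greedy matching: a prediction is a true positive if it overlaps an
    unmatched ground-truth box with IoU >= threshold."""
    matched, tp = set(), 0
    for pred in predictions:
        for j, gt in enumerate(ground_truths):
            if j not in matched and iou(pred, gt) >= threshold:
                matched.add(j)
                tp += 1
                break
    precision = tp / len(predictions) if predictions else 0.0
    recall = tp / len(ground_truths) if ground_truths else 0.0
    return precision, recall
```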
Limited by computational resources, the model shown above was trained on only 5 classes. The following two charts show the performance of the model for different values of the IoU threshold:
| Average Precision | Recall |
| --- | --- |
In addition to those charts, the following chart compares our results with those of the original paper:
Note that the model under study was trained on less data and tested on a different dataset than the original.
We would like to thank Rayan Samy for acting as our consultant; this project was inspired by his repository.