
Official implementation of "Learning to Terminate in Object Navigation", accepted to ACML 2023.


Learning to Terminate in Object Navigation

Yuhang Song, Anh Nguyen, Chun-Yi Lee

PyTorch implementation of our ACML 2023 paper, Learning to Terminate in Object Navigation, in the AI2-THOR environment. This implementation builds on SAVN and MJOLNIR_O. Please refer to our paper for more details.

DITA Visualization Demo


The offline data can be found here.

"data.zip" (~5 GB) contains everything needed for evaluation. Please unzip it and put it into the MJOLNIR folder.

For training, please also download "train.zip" (~9 GB) and put all "Floorplan" folders into ./data/thor_v1_offline_data.
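The folder placement described above can be scripted. The following is a minimal sketch, not part of the repository: the function name `collect_floorplans` and the extraction directory are hypothetical, and it assumes that the scene folders in "train.zip" have names starting with "Floorplan" (matched case-insensitively, since the exact capitalization is not confirmed here).

```python
import shutil
from pathlib import Path


def collect_floorplans(extracted_dir, target_dir="data/thor_v1_offline_data"):
    """Move folders whose names start with 'floorplan' (any case) into target_dir.

    extracted_dir is wherever train.zip was unpacked (an assumption of this
    sketch). Folders that already exist in target_dir are left untouched.
    Returns the names of the folders that were moved.
    """
    extracted = Path(extracted_dir)
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    moved = []
    # sorted() materializes the listing before anything is moved.
    for folder in sorted(extracted.rglob("*")):
        # Entries inside an already-moved folder no longer exist; skip them.
        if not folder.is_dir():
            continue
        if not folder.name.lower().startswith("floorplan"):
            continue
        destination = target / folder.name
        if destination.exists():
            continue
        shutil.move(str(folder), str(destination))
        moved.append(folder.name)
    return moved
```

For example, after unpacking "train.zip" into a directory named train_unzipped (a hypothetical name), calling `collect_floorplans("train_unzipped")` from the repository root would place the folders under ./data/thor_v1_offline_data.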


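Note that DITA requires a different agent_type from the other models in both training and testing: SupervisedNavigationAgent for DITA, and NavigationAgent for SAVN, GCN, and MJOLNIR_O.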


python main.py --eval \
    --test_or_val test \
    --episode_type TestValEpisode \
    --load_model pretrained_models/DITA.dat \
    --model DITA \
    --results_json dita.json \
    --gpu-ids 0 \
    --load_JG_model pretrained_models/JudgeModel.dat \
    --agent_type SupervisedNavigationAgent

Evaluating the DITA model automatically generates action log files for visualization.


If you have trained other models ("SAVN" or "GCN" or "MJOLNIR_O"), please evaluate them using the following command.

python main.py --eval \
    --test_or_val test \
    --episode_type TestValEpisode \
    --load_model [model_name] \
    --model MJOLNIR_O \
    --results_json mjolnir_o.json \
    --gpu-ids 0 \
    --agent_type NavigationAgent \
    --judge_model False

Valid --model values for this command are "SAVN", "GCN", and "MJOLNIR_O".
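Because the two evaluation commands differ only in a few flags, the pairing of model and agent_type can be captured in a small helper. This is a sketch under stated assumptions rather than code from the repository: `build_eval_command` and `AGENT_TYPES` are hypothetical names, the flags are copied from the commands above, and the results file name is simply the lowercased model name.

```python
import shlex

# Model-to-agent pairing as given by the evaluation commands in this README.
AGENT_TYPES = {
    "DITA": "SupervisedNavigationAgent",
    "SAVN": "NavigationAgent",
    "GCN": "NavigationAgent",
    "MJOLNIR_O": "NavigationAgent",
}


def build_eval_command(model, checkpoint, judge_checkpoint=None):
    """Return the evaluation command for `model` as a list of arguments."""
    if model not in AGENT_TYPES:
        raise ValueError(f"unknown model: {model}")

    args = [
        "python", "main.py", "--eval",
        "--test_or_val", "test",
        "--episode_type", "TestValEpisode",
        "--load_model", checkpoint,
        "--model", model,
        "--results_json", f"{model.lower()}.json",
        "--gpu-ids", "0",
        "--agent_type", AGENT_TYPES[model],
    ]
    if model == "DITA":
        # DITA additionally loads the judge model checkpoint.
        if judge_checkpoint is None:
            raise ValueError("DITA needs a judge model checkpoint")
        args += ["--load_JG_model", judge_checkpoint]
    else:
        # The other models are evaluated with the judge model disabled.
        args += ["--judge_model", "False"]
    return args


if __name__ == "__main__":
    cmd = build_eval_command(
        "DITA", "pretrained_models/DITA.dat", "pretrained_models/JudgeModel.dat"
    )
    # Print a shell-ready command line (shlex.join needs Python 3.8+).
    print(shlex.join(cmd))
```

The returned list can also be passed directly to `subprocess.run`, which avoids shell quoting issues with checkpoint paths.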


Note that our visualization only supports the DITA model.

cd visualization
python visualization.py --actionList ../dita_vis.log


As in evaluation, DITA is trained with a different agent_type (SupervisedNavigationAgent) from the other models (NavigationAgent).


python main.py \
    --title mjolnir_train \
    --model MJOLNIR_O \
    --gpu-ids 0 \
    --workers 8 \
    --vis False \
    --save-model-dir trained_models \
    --agent_type NavigationAgent

Valid --model values for this command are "SAVN", "GCN", and "MJOLNIR_O".


python main.py \
    --title DITA_training \
    --model DITA \
    --gpu-ids 0 \
    --workers 8 \
    --vis False \
    --save-model-dir trained_models \
    --agent_type SupervisedNavigationAgent