Heracles: Hybrid State Space Model and Transformer (Hierarchical) Model
Requirements
- PyTorch 1.10.0+
- Python 3.8
- CUDA 10.1+
- timm==0.4.5
- tlt==0.1.0
- pyyaml
- apex-amp
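The requirements above mix minimums (PyTorch 1.10.0+, CUDA 10.1+) with exact pins (timm==0.4.5, tlt==0.1.0). A quick way to check an environment before launching a long run is sketched here; it is not part of the repository. It assumes the pip distribution names `torch`, `timm`, `tlt` and `PyYAML`, treats Python 3.8 as a minimum, and skips CUDA and apex-amp, whose installed names and versions vary by setup (once PyTorch is installed, `torch.version.cuda` reports the CUDA version it was built against).

```python
"""Check installed package versions against the README requirement list."""
import re
import sys
from importlib import metadata  # standard library since Python 3.8

# Assumed pip distribution names -> (version tuple, "min" or "exact").
REQUIREMENTS = {
    "torch": ((1, 10, 0), "min"),   # PyTorch 1.10.0+
    "timm": ((0, 4, 5), "exact"),   # timm==0.4.5
    "tlt": ((0, 1, 0), "exact"),    # tlt==0.1.0
    "PyYAML": ((0, 0, 0), "min"),   # pyyaml: presence check only
}


def parse_version(text):
    """Turn e.g. '1.10.0+cu111' into (1, 10, 0), padding to three fields."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def satisfies(installed, required, mode):
    """True if `installed` meets a minimum ("min") or matches a pin ("exact")."""
    return installed >= required if mode == "min" else installed == required


def check_environment():
    """Return a list of human-readable problems (empty if everything checks out)."""
    problems = []
    if sys.version_info[:2] < (3, 8):  # assumption: 3.8 read as a minimum
        problems.append("Python 3.8 expected, found " + sys.version.split()[0])
    for name, (required, mode) in REQUIREMENTS.items():
        try:
            installed = parse_version(metadata.version(name))
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        if not satisfies(installed, required, mode):
            problems.append(f"{name}: found {installed}, need {mode} {required}")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["all checked requirements satisfied"]:
        print(line)
```

Exact pins are compared on the numeric fields only, so a local build suffix such as `+cu111` does not count as a mismatch.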
Train Heracles Small model
python3 -m torch.distributed.launch \
--nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
--master_addr="localhost" \
--master_port=12346 \
--use_env main.py \
--config configs/heracles/heracles_s.py \
--data-path /export/home/dataset/imagenet \
--epochs 310 \
--batch-size 128 \
--token-label \
--token-label-size 7 \
--token-label-data /export/home/dataset/imagenet/label_top5_train_nfnet
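With `--use_env`, `torch.distributed.launch` exports the local rank through the `LOCAL_RANK` environment variable instead of passing a `--local_rank` argument to the script, alongside `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`. The sketch below shows how a training script could read these values and what the command implies for the effective batch size. It is not taken from `main.py`: `read_dist_env` and `global_batch_size` are hypothetical helpers, and treating `--batch-size` as a per-process value is an assumption (common in DeiT-style training scripts, not confirmed here).

```python
import os


def read_dist_env(environ=None):
    """Read the rank information that `torch.distributed.launch --use_env` exports.

    Hypothetical helper, not part of main.py. Falls back to single-process
    defaults when the variables are absent (e.g. a plain `python3 main.py`).
    A real script would then call torch.cuda.set_device(local_rank) and
    torch.distributed.init_process_group(backend="nccl", init_method="env://").
    """
    env = os.environ if environ is None else environ
    return {
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "rank": int(env.get("RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
        "master": f"{env.get('MASTER_ADDR', 'localhost')}:{env.get('MASTER_PORT', '29500')}",
    }


def global_batch_size(per_process_batch, nproc_per_node, nnodes=1):
    """Images per optimizer step, assuming --batch-size is per process."""
    return per_process_batch * nproc_per_node * nnodes


if __name__ == "__main__":
    print(read_dist_env())
    # Values from the command above: 8 processes on 1 node, --batch-size 128.
    print(global_batch_size(128, nproc_per_node=8, nnodes=1))  # 1024
```

Under that assumption, each of the three commands trains with 8 × 128 = 1024 images per step; changing `--nproc_per_node` without adjusting `--batch-size` changes the effective batch size.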
Train Heracles Base model
python3 -m torch.distributed.launch \
--nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
--master_addr="localhost" \
--master_port=12346 \
--use_env main.py \
--config configs/heracles/heracles_b.py \
--data-path /export/home/dataset/imagenet \
--epochs 310 \
--batch-size 128 \
--token-label \
--token-label-size 7 \
--token-label-data /export/home/dataset/imagenet/label_top5_train_nfnet
Train Heracles Large model
python3 -m torch.distributed.launch \
--nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
--master_addr="localhost" \
--master_port=12346 \
--use_env main.py \
--config configs/heracles/heracles_l.py \
--data-path /export/home/dataset/imagenet \
--epochs 310 \
--batch-size 128 \
--token-label \
--token-label-size 7 \
--token-label-data /export/home/dataset/imagenet/label_top5_train_nfnet
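The Small, Base and Large commands differ only in the `--config` file (`heracles_s.py`, `heracles_b.py`, `heracles_l.py`); every other flag is shared. A sketch of a small helper that assembles the same command for a chosen variant follows. `VARIANTS` and `build_command` are hypothetical conveniences that mirror the flags used above, and the dataset paths default to the ones in the commands, which will likely need changing for another machine.

```python
import shlex

# Config files used by the three training commands.
VARIANTS = {
    "small": "configs/heracles/heracles_s.py",
    "base": "configs/heracles/heracles_b.py",
    "large": "configs/heracles/heracles_l.py",
}


def build_command(variant, data_path="/export/home/dataset/imagenet",
                  token_label_data="/export/home/dataset/imagenet/label_top5_train_nfnet",
                  nproc_per_node=8, master_port=12346, epochs=310, batch_size=128):
    """Return the launch command for one Heracles variant as an argv list."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant {variant!r}; expected one of {sorted(VARIANTS)}")
    return [
        "python3", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        "--nnodes=1",
        "--node_rank=0",
        "--master_addr=localhost",
        f"--master_port={master_port}",
        "--use_env", "main.py",          # launcher flag, then the training script
        "--config", VARIANTS[variant],   # the only flag that differs per variant
        "--data-path", data_path,
        "--epochs", str(epochs),
        "--batch-size", str(batch_size),
        "--token-label",
        "--token-label-size", "7",
        "--token-label-data", token_label_data,
    ]


if __name__ == "__main__":
    # Print all three commands; shlex.join requires Python 3.8+.
    for name in VARIANTS:
        print(shlex.join(build_command(name)))
    # To launch one: import subprocess; subprocess.run(build_command("small"), check=True)
```

Returning an argv list rather than a single string avoids shell quoting issues when the command is passed to `subprocess.run`; `shlex.join` is only used to print a copy-pasteable version.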