git clone https://github.com/zhengchen1999/NTIRE2024_ImageSR_x4.git
- Select the model you would like to test in `run.sh`:
  CUDA_VISIBLE_DEVICES=0 python test_demo.py --data_dir [path to your data dir] --save_dir [path to your save dir] --model_id 0
  - Be sure to change the directories `--data_dir` and `--save_dir`.
- We provide three baselines (team00): RFDN, SwinIR, and DAT. The code and pretrained models for all three are provided. Switch models (default is DAT) by commenting out the corresponding code in `test_demo.py`. All three baselines can be tested with `run.sh`.
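When several models or data folders need to be tested, the command in `run.sh` can also be assembled from Python. The following is a minimal sketch, not part of the repository: the helper names `build_test_command` and `build_env` are hypothetical, and only the flags `--data_dir`, `--save_dir`, and `--model_id` are taken from the command above.

```python
import os
import shlex


def build_test_command(data_dir, save_dir, model_id=0):
    """Return the argv list for one test_demo.py run (flags as in run.sh)."""
    return [
        "python", "test_demo.py",
        "--data_dir", str(data_dir),
        "--save_dir", str(save_dir),
        "--model_id", str(model_id),
    ]


def build_env(gpu="0"):
    """Copy the current environment and pin the run to a single GPU."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu)
    return env


if __name__ == "__main__":
    # Placeholder paths: replace them with your own directories.
    cmd = build_test_command("path/to/data", "path/to/results", model_id=0)
    print(shlex.join(cmd))
    # To launch the run from the repository root, uncomment:
    # import subprocess
    # subprocess.run(cmd, env=build_env("0"), check=True)
```

Keeping the command as a list avoids shell quoting problems when a directory name contains spaces.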
- Register your team in the Google Spreadsheet and get your team ID.
- Put the code of your model in `./models/[Your_Team_ID]_[Your_Model_Name].py`
  - Please add only one file in the folder `./models`. Please do not add other submodules.
  - Please zero-pad [Your_Team_ID] to two digits, e.g. 00, 01, 02
- Put the pretrained model in `./model_zoo/[Your_Team_ID]_[Your_Model_Name].[pth or pt or ckpt]`
  - Please zero-pad [Your_Team_ID] to two digits, e.g. 00, 01, 02
  - Note: Please provide a download link for the pretrained model if the file size exceeds 100 MB. Put the link in `./model_zoo/[Your_Team_ID]_[Your_Model_Name].txt`, e.g. team00_dat.txt
- Add your model to the model loader `./test_demo/select_model` as follows:
      elif model_id == [Your_Team_ID]:
          # define your model and load the checkpoint
  - Note: Please set the correct data_range, either 255.0 or 1.0
- Send us the command to download your code, e.g.,
  git clone [Your repository link]
  - We will do the following steps to add your code and model checkpoint to the repository.
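The file naming rules in the list above can be checked with a small helper before submitting. This is a sketch under stated assumptions: the function names are hypothetical, the `team` prefix is inferred from the `team00_dat.txt` example rather than stated explicitly, and "100 MB" is read here as 100 MiB.

```python
from pathlib import Path

SIZE_LIMIT_BYTES = 100 * 1024 * 1024  # "100 MB", read here as 100 MiB (assumption)
ALLOWED_EXTS = ("pth", "pt", "ckpt")


def submission_stem(team_id, model_name, prefix="team"):
    """Build the shared file stem, zero-padding the team ID to two digits.

    The "team" prefix is an assumption based on the team00_dat.txt example.
    """
    return f"{prefix}{int(team_id):02d}_{model_name}"


def submission_paths(team_id, model_name, ext="pth", prefix="team"):
    """Return the expected locations of the model code, weights, and link file."""
    if ext not in ALLOWED_EXTS:
        raise ValueError(f"extension must be one of {ALLOWED_EXTS}, got {ext!r}")
    stem = submission_stem(team_id, model_name, prefix)
    return {
        "code": Path("models") / f"{stem}.py",
        "weights": Path("model_zoo") / f"{stem}.{ext}",
        "link": Path("model_zoo") / f"{stem}.txt",
    }


def needs_download_link(size_bytes):
    """True when the checkpoint exceeds the size limit and a link file is needed."""
    return size_bytes > SIZE_LIMIT_BYTES
```

For example, `submission_paths(7, "MyNet")["weights"]` points to `model_zoo/team07_MyNet.pth`, and `needs_download_link(Path(...).stat().st_size)` tells you whether to add the `.txt` link file instead of the checkpoint.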
from utils.model_summary import get_model_flops
from models.team00_DAT import DAT
model = DAT()
input_dim = (3, 256, 256) # set the input dimension
flops = get_model_flops(model, input_dim, False)
flops = flops / 10 ** 9
print("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))
num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
num_parameters = num_parameters / 10 ** 6
print("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
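To sanity-check the reported numbers, it can help to count the parameters of a single layer by hand and pass the result through the same unit scaling and format string. The sketch below uses plain Python only; `conv2d_params` and `format_metric` are hypothetical helpers for illustration and are not part of `utils.model_summary`.

```python
def conv2d_params(in_channels, out_channels, kernel_size, bias=True):
    """Parameter count of a standard (non-grouped) 2D convolution."""
    weights = out_channels * in_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def format_metric(name, value, unit):
    """Format a value with the same layout as the print calls above."""
    return "{:>16s} : {:<.4f} [{}]".format(name, value, unit)


# A 3x3 convolution from 3 to 64 channels, with bias:
# 64 * 3 * 3 * 3 weights + 64 biases = 1792 parameters.
n = conv2d_params(3, 64, 3)
print(format_metric("#Params", n / 10 ** 6, "M"))
```

Dividing by `10 ** 6` reports parameters in millions, and dividing by `10 ** 9` reports FLOPs in billions, so a small layer like this one shows up as only a fraction of a million.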
This code repository is released under the MIT License.