PyTorch implementation of "MedSRGAN: medical images super-resolution using generative adversarial networks"
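Create the generator and discriminator:
import torch
from generator import Generator
from discriminator import Discriminator

# Generator for 3-channel input images; blocks sets the number of blocks in the network body
generator = Generator(
    in_channels=3,
    blocks=8,
)

# Discriminator for 3-channel images of size 256 x 256
discriminator = Discriminator(
    in_channels=3,
    img_size=(256, 256),
)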
To train and test the model on your own dataset, follow these steps:
- Create the `custom_dataset` folder in your project directory.
- Inside `custom_dataset`, create the `train_LR` (low-resolution inputs) and `train_HR` (high-resolution ground truth) subdirectories.
- Train the model by running the following command in the terminal:

  `python main.py --LR_path custom_dataset/train_LR --GT_path custom_dataset/train_HR`

  This trains the MedSRGAN model on your medical image dataset. Adjust the hyperparameters in `main.py` as needed.
- After training, test the model on new images with:

  `python tester.py`

  Enter the path of the test image when prompted.
- View the result, saved as `enhanced_output.jpeg`, in your root directory.
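The folder setup and training command from the steps above can be scripted. The sketch below is a hypothetical helper, not part of this repository: the folder names and the `--LR_path` / `--GT_path` flags come from this README, while the function names (`prepare_dataset_dirs`, `unpaired_images`, `train_command`) and the filename-pairing check are illustrative assumptions. In particular, it is not confirmed that `main.py` matches LR and HR images by filename.

```python
"""Hypothetical helper for the dataset layout described in this README.

Folder names and the --LR_path / --GT_path flags come from the README.
Assumption (unverified): main.py pairs LR and HR images by filename, so
the pairing check below is only a convenience sanity check.
"""
from __future__ import annotations

import shlex
from pathlib import Path

# Common image extensions; adjust to match your data (assumption).
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}


def prepare_dataset_dirs(root: Path) -> tuple[Path, Path]:
    """Create custom_dataset/train_LR and custom_dataset/train_HR under root."""
    dataset = root / "custom_dataset"
    lr_dir = dataset / "train_LR"
    hr_dir = dataset / "train_HR"
    for d in (lr_dir, hr_dir):
        d.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    return lr_dir, hr_dir


def _image_names(d: Path) -> set[str]:
    """Filenames of image files directly inside d."""
    return {
        p.name
        for p in d.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    }


def unpaired_images(lr_dir: Path, hr_dir: Path) -> tuple[set[str], set[str]]:
    """Return (names only in lr_dir, names only in hr_dir)."""
    lr, hr = _image_names(lr_dir), _image_names(hr_dir)
    return lr - hr, hr - lr


def train_command(lr_dir: Path, hr_dir: Path) -> list[str]:
    """Argument list for the README's training command."""
    return ["python", "main.py", "--LR_path", str(lr_dir), "--GT_path", str(hr_dir)]


if __name__ == "__main__":
    # Run from the project root so the relative paths match the README.
    lr_dir, hr_dir = prepare_dataset_dirs(Path("."))
    only_lr, only_hr = unpaired_images(lr_dir, hr_dir)
    if only_lr or only_hr:
        print("Only in train_LR:", sorted(only_lr))
        print("Only in train_HR:", sorted(only_hr))
    print(shlex.join(train_command(lr_dir, hr_dir)))
```

Run from the project root, the script creates any missing folders, lists images that appear in only one of the two folders, and prints the training command (on Linux or macOS, the same command shown in the steps above). Building the command as an argument list keeps it easy to pass to `subprocess.run` later without shell quoting issues.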