This project implements an image segmentation model using a U-Net architecture with a ResNet34 encoder. The model is trained on images stored in TIFF format, with the corresponding masks stored in PNG format. The implementation uses PyTorch and the segmentation-models-pytorch library for model building and training.
- Preprocessing of images and masks
- Binary segmentation using U-Net architecture
- Evaluation metrics: Loss, Dice score, F1 score, and Accuracy
- Flask deployment for serving the model
Make sure you have the following libraries installed:
- Python 3.x
- PyTorch
- segmentation-models-pytorch
- torchmetrics
- rasterio
- tifffile
- torchvision
- matplotlib
- scikit-learn
- Flask
Ensure you have your dataset organized as follows:
- Place your .tif images in the data/images/ directory.
- Place your corresponding .png masks in the data/labels/ directory.
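Each image has to be matched to its mask before training. The following is a minimal sketch of one way to do that, assuming (the layout above does not state this) that an image and its mask share the same filename stem, for example `tile_01.tif` and `tile_01.png`. The function name `pair_images_and_masks` is hypothetical and is not part of this project's code.

```python
from pathlib import Path


def pair_images_and_masks(image_dir="data/images", mask_dir="data/labels"):
    """Return sorted (image_path, mask_path) pairs matched by filename stem.

    Assumption: an image and its mask share the same stem
    (e.g. tile_01.tif <-> tile_01.png).
    """
    image_dir, mask_dir = Path(image_dir), Path(mask_dir)
    # Index the masks by stem so each image can be looked up directly.
    masks = {path.stem: path for path in mask_dir.glob("*.png")}

    pairs, missing = [], []
    for image_path in sorted(image_dir.glob("*.tif")):
        mask_path = masks.get(image_path.stem)
        if mask_path is None:
            missing.append(image_path.name)
        else:
            pairs.append((image_path, mask_path))

    # Fail early rather than silently training on a partial dataset.
    if missing:
        raise FileNotFoundError(f"No mask found for: {', '.join(missing)}")
    return pairs
```

Raising on a missing mask is a design choice for this sketch; skipping unmatched images with a warning would be an equally reasonable alternative.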
To run this project, you'll need to install the required packages. You can do this using pip:
pip install -r requirements.txt
To train the model, run the following command in your terminal:
python src/train.py
This will initiate the training process, using the images and masks you prepared earlier. The training loop will print the training and validation metrics, including loss, Dice score, F1 score, and accuracy.
To evaluate the trained model on a test set, make sure your test dataset is organized the same way as the training data, then run:
python src/evaluate.py
This will output the test loss, Dice score, F1 score, and accuracy.
Upon evaluation, you can expect output metrics similar to the following:
Test Loss: 0.6287
Test Dice Score: 0.8970
Test F1 Score: 0.8970
Test Accuracy: 0.9530
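For binary masks, the Dice score and the F1 score are the same quantity, 2TP / (2TP + FP + FN), which is consistent with the identical Dice and F1 values in the sample output above. The sketch below illustrates this with plain NumPy. It is not the code in `src/evaluate.py`, and the function name `binary_metrics` and the `eps` smoothing term are assumptions made for illustration.

```python
import numpy as np


def binary_metrics(pred, target, eps=1e-7):
    """Compute Dice, F1, and pixel accuracy for two binary masks of equal shape."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)

    # Confusion-matrix counts over all pixels.
    tp = np.sum(pred & target)
    fp = np.sum(pred & ~target)
    fn = np.sum(~pred & target)
    tn = np.sum(~pred & ~target)

    # Dice computed directly from the counts.
    dice = 2 * tp / (2 * tp + fp + fn + eps)

    # F1 computed via precision and recall; it matches Dice up to eps.
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    accuracy = (tp + tn) / pred.size
    return {"dice": float(dice), "f1": float(f1), "accuracy": float(accuracy)}


# Small example: one true positive, one false positive, two true negatives.
print(binary_metrics([[1, 1], [0, 0]], [[1, 0], [0, 0]]))
```

Accuracy counts background pixels as correct predictions, so it can be high even when the foreground overlap is modest. That is one reason to report Dice alongside it.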
To deploy the model as a web application using Flask, follow these steps:
- Install Flask:
pip install Flask
- Create an app.py file in the root of your project with the following content:
from flask import Flask, request, jsonify
import numpy as np
import torch
import segmentation_models_pytorch as smp
from PIL import Image
from torchvision import transforms

app = Flask(__name__)

# Build the same architecture used for training. encoder_weights=None skips the
# ImageNet download, because the trained weights are loaded from the checkpoint.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1,
)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load('path_to_your_trained_model.pth', map_location="cpu"))
model.eval()

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the input image to the model's expected input size
    transforms.ToTensor(),          # Convert the image to a float tensor in [0, 1]
])


@app.route('/predict', methods=['POST'])
def predict():
    # Check if an image file was uploaded
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    # If the user does not select a file, the browser submits an empty file without a filename
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    # Process the input image
    img = Image.open(file.stream).convert("RGB")  # Read the image and convert to RGB
    img = transform(img)                          # Apply the transformations
    img = img.unsqueeze(0)                        # Add a batch dimension

    # Perform inference without tracking gradients
    with torch.no_grad():
        output = model(img)
        prediction = torch.sigmoid(output).squeeze().numpy()  # Apply sigmoid and convert to a NumPy array

    # Convert the prediction to a binary mask by thresholding at 0.5
    binary_mask = (prediction > 0.5).astype(np.uint8)

    return jsonify({'prediction': binary_mask.tolist()})  # Return the mask as nested lists


if __name__ == '__main__':
    app.run(debug=True)  # debug=True is for local development only
- Run the Flask app: In your terminal, run
python app.py
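Once the server is running, a client can POST an image under the form field `file` to `/predict`. The JSON response holds the binary mask as nested lists under the `prediction` key, or an `error` message with status 400. The sketch below decodes such a response into a NumPy array. The helper name `decode_prediction` and the file name `sample.tif` are hypothetical. The commented-out request lines assume the third-party `requests` package, which is not in the requirements list above, and Flask's default development address `http://127.0.0.1:5000`.

```python
import numpy as np


def decode_prediction(payload):
    """Turn the JSON body returned by /predict into a 2-D uint8 mask."""
    if "error" in payload:
        raise ValueError(f"Server reported an error: {payload['error']}")
    mask = np.asarray(payload["prediction"], dtype=np.uint8)
    if mask.ndim != 2:
        raise ValueError(f"Expected a 2-D mask, got shape {mask.shape}")
    return mask


# Example with a hand-written payload shaped like the server's response.
mask = decode_prediction({"prediction": [[0, 1], [1, 1]]})
print(mask.shape, f"foreground fraction: {mask.mean():.2f}")

# Against a running server (assumes `requests` is installed; sample.tif is a placeholder):
# import requests
# with open("sample.tif", "rb") as f:
#     response = requests.post("http://127.0.0.1:5000/predict", files={"file": f})
# mask = decode_prediction(response.json())
```

Because the app resizes every input to 256 by 256, the returned mask has that size regardless of the uploaded image's dimensions. A client that needs the mask at the original resolution would have to resize it afterwards.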
This project is licensed under the MIT License. See the LICENSE file for more information.
- U-Net: Convolutional Networks for Biomedical Image Segmentation
- Segmentation Models PyTorch: A library that provides implementations of segmentation models in PyTorch.