Author: Cheng Xn.
This document gives an overview of the Python script for unsupervised neural-network digital image correlation (DIC). The script includes the following key components:
- Model definition and setup
- Data loading and preprocessing
- Training, validation, and testing procedures
- Loss and accuracy tracking
- Checkpoint saving and loading
- Visualization and result saving
The script starts with the necessary import statements for various Python packages like time, math, random, shutil, pdb, scipy, numpy, torch, matplotlib, pandas, and specific utility modules for data reading and model definition.
The Args class defines the parameters and configurations for the training process, including paths for data and checkpoints, hyperparameters like learning rate and batch size, and options for GPU/CUDA usage.
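As a rough illustration of such a configuration class, here is a minimal sketch. The class name ExampleArgs, every field name, and every default value are hypothetical placeholders and are not taken from the script.

```python
from dataclasses import dataclass


@dataclass
class ExampleArgs:
    # All names and defaults below are illustrative assumptions.
    data_dir: str = "./data"                # where the image data is read from
    checkpoint_dir: str = "./checkpoints"   # where checkpoints are written
    lr: float = 1e-3                        # initial learning rate
    batch_size: int = 8                     # samples per training batch
    epochs: int = 100                       # number of training epochs
    use_cuda: bool = True                   # request the GPU when one is available
    seed: int = 0                           # random seed for reproducibility


# Override only the fields that differ from the defaults.
args = ExampleArgs(batch_size=16)
```

Keeping all settings in one object means a run can be reconfigured by editing a single place rather than scattered constants.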
The imagesc function is used for visualizing the results of the DL-DIC model. It plots the predicted and real displacement fields for a subset of the data.
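A side-by-side comparison of this kind could look like the sketch below. It is not the script's imagesc function: the name plot_displacement, the colormap, and the shared color scale are assumptions chosen for illustration, and the fields are assumed to be 2D NumPy arrays.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so figures can be saved without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_displacement(pred, real, path=None):
    """Plot predicted and real displacement fields side by side (hypothetical helper)."""
    # Share one color scale so the two panels are directly comparable.
    vmin = min(pred.min(), real.min())
    vmax = max(pred.max(), real.max())
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, field, title in zip(axes, (pred, real), ("Predicted", "Real")):
        im = ax.imshow(field, vmin=vmin, vmax=vmax, cmap="jet")
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
    if path is not None:
        fig.savefig(path, dpi=150)
    return fig
```

Calling plot_displacement(pred, real, "compare.png") with two arrays of the same shape would write both panels to one image file.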
This function adjusts the learning rate for the optimizer based on the current epoch and the specified learning rate schedule.
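The schedule the script uses is not stated here, so the sketch below assumes a simple step decay. The function name and the decay_epochs and gamma parameters are hypothetical. The only PyTorch convention relied on is that optimizers expose their settings through param_groups.

```python
def adjust_learning_rate(optimizer, epoch, base_lr, decay_epochs=30, gamma=0.1):
    """Step decay (assumed schedule): scale base_lr by gamma every decay_epochs epochs."""
    lr = base_lr * (gamma ** (epoch // decay_epochs))
    # Write the new rate into every parameter group of the optimizer.
    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr
```

With these example defaults, epochs 0 to 29 would use base_lr, epochs 30 to 59 one tenth of it, and so on.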
The cls_train function handles the training of the model for one epoch. It computes the loss, performs backpropagation, and updates the model's weights.
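The three steps named above (loss, backpropagation, weight update) follow the usual PyTorch pattern, sketched below. This is not the script's cls_train: the function name is hypothetical, and the sketch assumes each batch yields an (inputs, targets) pair, whereas the script's unsupervised loss may be built differently.

```python
import torch
from torch import nn


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over loader and return the mean per-batch loss."""
    model.train()  # enable training-mode behavior (dropout, batch norm updates)
    total, batches = 0.0, 0
    for inputs, targets in loader:  # assumed batch layout
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()                     # clear gradients from the previous step
        loss = criterion(model(inputs), targets)  # forward pass and loss
        loss.backward()                           # backpropagation
        optimizer.step()                          # weight update
        total += loss.item()
        batches += 1
    return total / max(batches, 1)
```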
This function performs validation of the model on a validation dataset and computes the average loss.
Similar to cls_validate, the cls_test function tests the model on a test dataset and computes the average loss.
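Validation and testing differ from training mainly in that no gradients are computed and no weights change. The sketch below shows one way to compute an average loss under that constraint. The function name evaluate and the (inputs, targets) batch layout are assumptions, not the script's actual code.

```python
import torch
from torch import nn


def evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss over loader without updating the model."""
    model.eval()  # switch dropout and batch norm to inference behavior
    total, batches = 0.0, 0
    with torch.no_grad():  # skip gradient tracking to save memory and time
        for inputs, targets in loader:  # assumed batch layout
            inputs, targets = inputs.to(device), targets.to(device)
            total += criterion(model(inputs), targets).item()
            batches += 1
    return total / max(batches, 1)
```

The same helper can serve both the validation and the test split, since only the data loader differs.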
This function sets the random seeds for reproducibility of the results.
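A typical way to seed the random number generators of the packages the script imports is sketched below. The function name set_seed is an assumption, and which generators the script actually seeds is not stated in this document.

```python
import random

import numpy as np
import torch


def set_seed(seed):
    """Seed the Python, NumPy, and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all GPUs; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
```

Seeding alone may not make GPU runs bit-for-bit identical, since some CUDA operations are nondeterministic.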
The main execution block of the script initializes the model, criterion, optimizer, and data loaders for training, validation, and testing datasets. It also handles checkpoint loading if resuming from a previous training session.
The training loop iterates over the specified number of epochs, calling the cls_train, cls_validate, and cls_test functions in each iteration. It also saves checkpoints at regular intervals and records the losses.
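Saving and resuming can be sketched with torch.save and torch.load as below. The helper names, the dictionary keys, and the choice to store the epoch alongside the model and optimizer state are assumptions made for illustration.

```python
import torch
from torch import nn


def save_checkpoint(path, model, optimizer, epoch):
    """Write model and optimizer state plus the finished epoch index to path."""
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore state in place and return the epoch at which to resume."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1
```

Inside an epoch loop, a condition such as (epoch + 1) % save_every == 0, with save_every a hypothetical interval setting, would trigger save_checkpoint at regular intervals.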
The script creates a pandas DataFrame with the recorded losses and saves it to a CSV file.
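Writing the recorded losses to disk takes only a few lines with pandas. In the sketch below the function name, the column names, and the one-row-per-epoch layout are assumptions. The script's actual CSV layout is not specified here.

```python
import pandas as pd


def save_losses(train_losses, val_losses, test_losses, path):
    """Store one row per epoch with the three recorded losses (assumed layout)."""
    df = pd.DataFrame(
        {
            "epoch": list(range(1, len(train_losses) + 1)),
            "train_loss": train_losses,
            "val_loss": val_losses,
            "test_loss": test_losses,
        }
    )
    df.to_csv(path, index=False)  # index=False avoids an extra unnamed column
    return df
```

The resulting file can be read back with pd.read_csv for plotting or for comparing runs.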
To use this script, ensure that the required Python packages are installed and that the data is organized according to the paths specified in the Args_cxn class. Adjust the parameters in the Args_cxn class as needed for your specific use case. Run the script in an environment where PyTorch and the other dependencies are available.
- The visualization part of the script (the imagesc function) is commented out and should be enabled if visualization is required.
- Ensure that the paths for data and checkpoints are correctly set up before running the script.
- The script includes error handling for CUDA availability, which allows it to run on both CPU and GPU environments.
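The CPU/GPU fallback mentioned in the last note is commonly written as a small device-selection helper, sketched below. The function name select_device and its use_cuda flag are hypothetical. The script's own handling may differ.

```python
import torch


def select_device(use_cuda=True):
    """Return the CUDA device when requested and available, else the CPU."""
    if use_cuda and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = select_device()
```

Moving the model and each batch to this device with .to(device) lets the same code run in both environments.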