Generate idx for train/val/test dataloaders from raster mask values
LucaRom opened this issue · 1 comments
Here's an implementation that creates sets of idx for train/val and test from a raster mask in which each pixel value identifies the target zone/dataloader (e.g. 0 = train/val, 1 = test). It can be useful in some cases (e.g. when you need a specific area for validation or test, or when you train on overlapping tiles).
It's a basic implementation so I think it can be easily adapted, modified or improved.
This function generates train/val and test indexes from a given image file. It uses the rasterio library to iterate through all the windows in the image using a given window size (wd_size) and step size (step_size). For each window, the function reads the mask and zones arrays and counts the pixels that belong to the test zone and the non-habitable (background) pixels. Based on these counts, it assigns the window to the train/val set or the test set, or removes the window from the set.
The function also has an option to drop tiles made up entirely of background pixels (here class 7). Finally, it saves the train/val and test indexes to numpy files.
To use this code, you need to provide the path to the image file (img_path) that you want to generate the indexes for. Here it's a 2-channel image, but you can use a single channel for the zones and remove the code related to "mask" if you do not need to manipulate tiles according to labels.
# Generate train/val and test indexes from zones and mask
import numpy as np
import rasterio
from tqdm import tqdm
# iter_windows is not defined in this snippet; it is assumed to be a
# sliding-window helper available in the calling project.

def generate_indexes(img_path, wd_size, step_size, trim_background=True):
    with rasterio.open(img_path) as ds:
        windows_num = len(list(iter_windows(ds, step_size, wd_size, wd_size, strict_shape=False)))
        test_indices = []
        train_indices = []
        full_nh_tiles_idx = []
        bad_conf_list = []
        removed = []
        for idx, a_window in tqdm(enumerate(iter_windows(ds, step_size, wd_size, wd_size, strict_shape=False)), total=windows_num, desc='Windows'):
            mask = ds.read(1, window=a_window)   # band 1: class labels
            zones = ds.read(2, window=a_window)  # band 2: zones (1 = test)
            nh_pixels = np.count_nonzero(mask == 7)
            test_pixels = np.count_nonzero(zones == 1)
            tile_pixels_num = wd_size * wd_size
            # Drop windows that run past the image edge
            if a_window.col_off + wd_size > ds.width or a_window.row_off + wd_size > ds.height:
                removed.append(idx)
            else:
                if test_pixels == tile_pixels_num:
                    # Every pixel lies in the test zone
                    test_indices.append(idx)
                elif trim_background and nh_pixels == tile_pixels_num:
                    # Tile is entirely background: keep it out of both sets
                    full_nh_tiles_idx.append(idx)
                else:
                    train_indices.append(idx)
    print('Train val idx len:', len(train_indices))
    print('Test idx len :', len(test_indices))
    print('Removed idx len :', len(removed))
    print('Bad_conf len :', len(bad_conf_list))
    print('Total kept idx :', len(train_indices) + len(test_indices))
    print('Total number of idx :', len(train_indices) + len(test_indices) + len(removed) + len(bad_conf_list))
    print('Full NH idx len :', len(full_nh_tiles_idx))
    # np.save appends '.npy'; the 'results' directory must already exist
    trainval_idx_path = 'results/trainval_idx'
    test_idx_path = 'results/test_idx'
    np.save(trainval_idx_path, train_indices)
    np.save(test_idx_path, test_indices)
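The branching at the heart of the function can be isolated as a small pure function, which makes the rule easier to check on its own. This is only a sketch: the name assign_tile and the returned labels are hypothetical and not part of the original code, and it assumes the tile has already passed the image-edge check.

```python
def assign_tile(test_pixels, nh_pixels, tile_pixels_num, trim_background=True):
    """Return 'test', 'background' or 'trainval' for one full-size tile.

    Hypothetical helper mirroring the if/elif/else in generate_indexes.
    """
    if test_pixels == tile_pixels_num:
        # Every pixel of the tile lies in the test zone
        return "test"
    if trim_background and nh_pixels == tile_pixels_num:
        # Tile holds nothing but background (class 7 in the issue)
        return "background"
    # Everything else, including tiles that only partly touch the test zone
    return "trainval"


tile = 256 * 256
print(assign_tile(tile, 0, tile))            # fully inside the test zone
print(assign_tile(tile // 2, 0, tile))       # straddles the zone boundary
print(assign_tile(0, tile, tile))            # background only, trimming on
print(assign_tile(0, tile, tile, False))     # background only, trimming off
```

One consequence worth noticing: a tile that only partly covers the test zone is assigned to train/val, so with overlapping tiles some test-zone pixels can still be seen during training unless those boundary tiles are handled separately.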
Run the function on your mask image:
img_path = 'results/testzone_v03.tif'
generate_indexes(img_path, wd_size=256, step_size=128, trim_background=False)
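To get a feel for what wd_size=256 and step_size=128 produce, here is a rough stand-in for the window iteration written in plain Python. The iter_windows helper itself is not shown in this issue, so the row-major order and the function name window_offsets are assumptions; the real helper may enumerate windows differently, and the idx values only make sense if they match the order used by your dataset.

```python
def window_offsets(width, height, wd_size, step_size):
    """Yield (idx, col_off, row_off, is_full) for a sliding window.

    Hypothetical stand-in for iter_windows: row-major order is assumed.
    is_full is False when the window would run past the image edge,
    which is the case generate_indexes puts in its 'removed' list.
    """
    idx = 0
    for row_off in range(0, height, step_size):
        for col_off in range(0, width, step_size):
            is_full = (col_off + wd_size <= width
                       and row_off + wd_size <= height)
            yield idx, col_off, row_off, is_full
            idx += 1


# A small 512 x 256 image, tiled with 50% overlap
tiles = list(window_offsets(width=512, height=256, wd_size=256, step_size=128))
kept = [t for t in tiles if t[3]]
print(len(tiles), "windows,", len(kept), "kept")
print("kept idx:", [t[0] for t in kept])
```

With a step of half the window size, each kept tile shares half of its pixels with its horizontal neighbour, which is why the split between train and validation tiles needs some care.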
Then you can use SubsetRandomSampler() to load the created idx sets and use them in the dataloaders.
Note that the validation set here is still drawn from the train/val idx (trainval_idx_lst), which should be avoided if your tiles overlap.
# Excerpt from the function that builds the loaders; train_ds, train_ds_test,
# batch_size, num_workers and pin_memory are defined by the caller.
# Requires: import numpy as np
#           from torch.utils.data import DataLoader, SubsetRandomSampler
trainval_idx_lst = np.load('results/trainval_idx.npy')
test_idx_lst = np.load('results/test_idx.npy')
# Shuffle once, then split: the first 10% for validation, the rest for training
shuffled_trainval = np.random.permutation(trainval_idx_lst)
val_size = round(len(trainval_idx_lst) * 0.1)  # 10%
val_idx = shuffled_trainval[:val_size].tolist()
train_idx = shuffled_trainval[val_size:].tolist()
test_idx = test_idx_lst
# Creating sampler callable in dataloaders
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
test_sampler = SubsetRandomSampler(test_idx)
# Returning data loaders
train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, sampler=train_sampler)
val_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=0, pin_memory=pin_memory, sampler=val_sampler)
test_loader = DataLoader(train_ds_test, batch_size=1, num_workers=0, pin_memory=False, sampler=test_sampler)
return train_loader, val_loader, test_loader
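One possible way to deal with the overlap caveat mentioned above is to discard any training tile that shares pixels with a validation tile. The sketch below is not part of the original code: it assumes you also recorded each tile's (col_off, row_off) while iterating the windows, in a hypothetical offsets mapping keyed by idx, and that all tiles are squares of side wd_size.

```python
def overlaps(a, b, wd_size):
    """True if two square tiles, given as (col_off, row_off), share a pixel."""
    return abs(a[0] - b[0]) < wd_size and abs(a[1] - b[1]) < wd_size


def drop_overlapping(train_idx, val_idx, offsets, wd_size):
    """Keep only the training idx whose tile touches no validation tile.

    offsets is a hypothetical dict {idx: (col_off, row_off)} that would
    have to be saved alongside the idx lists.
    """
    val_offsets = [offsets[i] for i in val_idx]
    return [i for i in train_idx
            if not any(overlaps(offsets[i], v, wd_size) for v in val_offsets)]


# Four tiles in a row, 256 px wide, stepped by 128 px
offsets = {0: (0, 0), 1: (128, 0), 2: (256, 0), 3: (384, 0)}
clean_train = drop_overlapping([1, 2, 3], [0], offsets, wd_size=256)
print(clean_train)  # tile 1 overlaps validation tile 0 and is dropped
```

This costs some training tiles around every validation tile, so with a random 10% split and heavy overlap the loss can be large; choosing the validation tiles from a contiguous area (for example a third zone value in the raster) would keep more of the training set.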
Closed for now. Might be used later.