ImageFolder __getitem__ method returning incorrect horizontal_flip transformation parameter
Current Behavior:
The following line applies a horizontal flip to the augmented image with 50% probability, but because self.hor_flip(aug_sample) is not deterministic, the augmented image does not always correspond to the hor_flip parameter:
CGC/datasets/imagefolder_cgc_ssl.py, line 163 (commit a66d872)
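The snippet at that line is not reproduced here, so the following is only a minimal sketch of the pattern described above. It assumes the flag is drawn with probability 0.5 and the image is then passed to a transform that makes its own 50% decision. `random_hflip`, `buggy_augment`, and `count_mismatches` are hypothetical names, and numpy arrays stand in for tensors; this is not the repository's code.

```python
import random

import numpy as np


def random_hflip(image, p=0.5, rng=random):
    """Hypothetical stand-in for RandomHorizontalFlip with its default p=0.5."""
    if rng.random() < p:
        return np.flip(image, axis=-1)  # mirror along the width axis
    return image


def buggy_augment(sample, rng=random):
    """Hypothetical version of the reported pattern: flag and flip are drawn independently."""
    hor_flip = rng.random() < 0.5
    aug_sample = sample
    if hor_flip:
        # A second coin toss happens inside the transform, so the flip may not happen
        aug_sample = random_hflip(aug_sample, rng=rng)
    return aug_sample, hor_flip


def count_mismatches(n_trials=10_000, seed=0):
    """Count samples whose hor_flip flag disagrees with whether the image was actually flipped."""
    rng = random.Random(seed)
    sample = np.arange(12).reshape(3, 4)  # asymmetric, so a flip is always detectable
    flipped = np.flip(sample, axis=-1)
    mismatches = 0
    for _ in range(n_trials):
        aug, flag = buggy_augment(sample, rng=rng)
        mismatches += int(flag != np.array_equal(aug, flipped))
    return mismatches


print(count_mismatches() / 10_000)  # prints a value close to 0.25
```

In this sketch every mismatch is a sample whose flag is True but whose image was left unflipped; when the flag is False, no flip is attempted at all.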
Expected Behavior:
The hor_flip parameter should be True if and only if the augmented image is a flipped version of the sample (possibly with some crop).
This can be done by setting self.hor_flip = tvf.hflip, which flips unconditionally.
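A minimal sketch of the expected behavior, assuming the flag is drawn once and then applied with a deterministic flip (the role tvf.hflip plays in the suggested fix). Here `np.flip` stands in for the tensor flip, and `fixed_augment` and `flag_always_matches` are hypothetical names used only for illustration.

```python
import random

import numpy as np


def fixed_augment(sample, rng=random):
    """Hypothetical fix: draw the flag once, then flip deterministically iff it is True."""
    hor_flip = rng.random() < 0.5
    # np.flip always flips, like a deterministic hflip; there is no second coin toss
    aug_sample = np.flip(sample, axis=-1) if hor_flip else sample
    return aug_sample, hor_flip


def flag_always_matches(n_trials=1_000, seed=0):
    """Return True if every sample's flag agrees with whether it was actually flipped."""
    rng = random.Random(seed)
    sample = np.arange(12).reshape(3, 4)  # asymmetric test image
    flipped = np.flip(sample, axis=-1)
    for _ in range(n_trials):
        aug, flag = fixed_augment(sample, rng=rng)
        if flag != np.array_equal(aug, flipped):
            return False
    return True
```

Because the only random decision is the one recorded in the flag, the flag and the image can no longer disagree.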
Steps To Reproduce:
The following code was used to visualize the original and augmented tensors and confirm that the parameter sometimes does not correspond to the augmented image:
import matplotlib.pyplot as plt
import numpy as np


def display_tensors(tensor1, tensor2, hor_flip):
    """Show two image tensors side by side, titled with the hor_flip parameter."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    for i, tensor in enumerate([tensor1, tensor2]):
        # Convert the tensor to a numpy array (detach and move to CPU first)
        image_np = tensor.detach().cpu().numpy()
        # Scale the values to the [0, 1] range (guard against constant images)
        value_range = image_np.max() - image_np.min()
        image_np = (image_np - image_np.min()) / (value_range if value_range > 0 else 1.0)
        # Transpose from (channels, height, width) to (height, width, channels) if needed
        if image_np.shape[0] == 3:
            image_np = np.transpose(image_np, (1, 2, 0))
        # Display the image
        axs[i].imshow(image_np)
        axs[i].set_title(f"Flipped? {hor_flip}")
    plt.show(block=True)
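Visual inspection only catches the cases one happens to look at. A hedged alternative is to check the relationship numerically. This sketch assumes no crop, resize, or color change between the sample and the augmented image (the issue notes a crop may be applied, in which case an exact comparison will not work). `is_hflip_of` and `check_flag` are hypothetical helpers that take numpy arrays in (channels, height, width) layout, such as the output of `tensor.numpy()`.

```python
import numpy as np


def is_hflip_of(original, augmented, atol=1e-6):
    """Return True if `augmented` equals `original` mirrored along its last (width) axis.

    Assumes no crop, resize, or color change was applied between the two arrays.
    An image that is symmetric along its width will match both ways.
    """
    original = np.asarray(original, dtype=np.float64)
    augmented = np.asarray(augmented, dtype=np.float64)
    if original.shape != augmented.shape:
        return False
    return bool(np.allclose(augmented, np.flip(original, axis=-1), atol=atol))


def check_flag(original, augmented, hor_flip):
    """Return True if the hor_flip flag agrees with the actual relationship between the images."""
    return is_hflip_of(original, augmented) == bool(hor_flip)
```

Running `check_flag` over a batch of samples turns the visual check into a count of mismatched flags.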
Anything else:
I am using your CGC paper for reference: https://arxiv.org/pdf/2110.00527.pdf
Hi @patrik-bartak, yes, this seems to be an issue: when the hor_flip parameter is True, the image may still not be horizontally flipped, because self.hor_flip uses the default probability of 50%. This happens with probability 0.5*0.5 = 0.25, i.e. for about a quarter of the samples, which is probably why it did not noticeably degrade the final trained model. As a fix, the initialization of self.hor_flip should set the probability to 1, i.e. the following line
CGC/datasets/imagefolder_cgc_ssl.py, line 112 (commit a66d872)
should be changed to
self.hor_flip = transforms.RandomHorizontalFlip(p=1.0)
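The 0.25 figure can be checked by enumerating the four outcomes of the two coin tosses. This sketch assumes, as described above, that the flag and the transform's internal decision are independent, and that no flip is attempted when the flag is False. `mismatch_probability` is a hypothetical helper. Setting `p_flip` to 1 models the suggested `p=1.0` fix.

```python
from fractions import Fraction
from itertools import product


def mismatch_probability(p_flag=Fraction(1, 2), p_flip=Fraction(1, 2)):
    """Exact probability that the hor_flip flag disagrees with the actual flip.

    Assumes the transform is only called when the flag is True and that its
    internal coin toss is independent of the flag.
    """
    total = Fraction(0)
    for flag, coin in product([True, False], repeat=2):
        prob = (p_flag if flag else 1 - p_flag) * (p_flip if coin else 1 - p_flip)
        actually_flipped = flag and coin  # no flip is attempted when the flag is False
        if flag != actually_flipped:
            total += prob
    return total


print(mismatch_probability())  # 1/4
print(mismatch_probability(p_flip=Fraction(1)))  # 0
```

The only mismatching outcome is "flag True, transform does not flip", which has probability 0.5 * (1 - p_flip). That is why forcing `p_flip` to 1 removes it.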