KamilZywanowski/MinkLoc3D-SI

Question about pos and neg mask

Closed this issue · 1 comment

Hi, thanks for your great work.
I'm a new learner and got confused by the make_collate_fn part while reading your code:

```python
def make_collate_fn(dataset: OxfordDataset, version, dataset_name, mink_quantization_size=None):
    # set_transform: the transform to be applied to all batch elements
    def collate_fn(data_list):
        # Constructs a batch object
        clouds = [e[0] for e in data_list]
        labels = [e[1] for e in data_list]
        batch = torch.stack(clouds, dim=0)  # Produces (batch_size, n_points, point_dim) tensor
        if dataset.set_transform is not None:
            # Apply the same transformation on all dataset elements
            batch = dataset.set_transform(batch)

        if mink_quantization_size is None:
            # Not a MinkowskiEngine based model
            batch = {'cloud': batch}
        else:
            if version == 'MinkLoc3D':
                coords = [ME.utils.sparse_quantize(coordinates=e, quantization_size=mink_quantization_size)
                          for e in batch]
                coords = ME.utils.batched_coordinates(coords)
                # Assign a dummy feature equal to 1 to each point
                # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation
                feats = torch.ones((coords.shape[0], 1), dtype=torch.float32)
            elif version == 'MinkLoc3D-I':
                coords = []
                feats = []
                for e in batch:
                    c, f = ME.utils.sparse_quantize(coordinates=e[:, :3], features=e[:, 3].reshape([-1, 1]),
                                                    quantization_size=mink_quantization_size)
                    coords.append(c)
                    feats.append(f)
                coords = ME.utils.batched_coordinates(coords)
                feats = torch.cat(feats, dim=0)
            elif version == 'MinkLoc3D-S':
                coords = []
                for e in batch:
                    # Convert coordinates to spherical
                    spherical_e = torch.tensor(to_spherical(e.numpy(), dataset_name), dtype=torch.float)
                    c = ME.utils.sparse_quantize(coordinates=spherical_e[:, :3],
                                                 quantization_size=mink_quantization_size)
                    coords.append(c)
                coords = ME.utils.batched_coordinates(coords)
                feats = torch.ones((coords.shape[0], 1), dtype=torch.float32)
            elif version == 'MinkLoc3D-SI':
                coords = []
                feats = []
                for e in batch:
                    # Convert coordinates to spherical
                    spherical_e = torch.tensor(to_spherical(e.numpy(), dataset_name), dtype=torch.float)
                    c, f = ME.utils.sparse_quantize(coordinates=spherical_e[:, :3],
                                                    features=spherical_e[:, 3].reshape([-1, 1]),
                                                    quantization_size=mink_quantization_size)
                    coords.append(c)
                    feats.append(f)
                coords = ME.utils.batched_coordinates(coords)
                feats = torch.cat(feats, dim=0)
            batch = {'coords': coords, 'features': feats}

        # Compute positives and negatives mask
        # dataset.queries[label]['positives'] is bitarray
        positives_mask = [[dataset.queries[label]['positives'][e] for e in labels] for label in labels]
        negatives_mask = [[dataset.queries[label]['negatives'][e] for e in labels] for label in labels]
        positives_mask = torch.tensor(positives_mask)
        negatives_mask = torch.tensor(negatives_mask)

        # Returns (batch_size, n_points, 3) tensor and positives_mask and
        # negatives_mask which are batch_size x batch_size boolean tensors
        return batch, positives_mask, negatives_mask

    return collate_fn
```
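
For context on how I'm exercising this function, here is a minimal toy call of the returned collate_fn (assuming make_collate_fn and its imports are in scope; the stand-in dataset, shapes, and parameter values below are mine, not the repo's actual training setup):

```python
from types import SimpleNamespace
import torch

# Stand-in exposing only the attributes collate_fn touches; the real code
# uses OxfordDataset. The positive/negative entries here are placeholder bits.
queries = {
    0: {'positives': [0, 1], 'negatives': [0, 0]},
    1: {'positives': [1, 0], 'negatives': [0, 0]},
}
toy_dataset = SimpleNamespace(set_transform=None, queries=queries)

# Two (cloud, label) pairs, as produced by the dataset's __getitem__
data_list = [(torch.rand(4096, 3), 0), (torch.rand(4096, 3), 1)]

collate_fn = make_collate_fn(toy_dataset, version='MinkLoc3D',
                             dataset_name='USyd', mink_quantization_size=0.01)
batch, positives_mask, negatives_mask = collate_fn(data_list)
print(batch['coords'].shape)   # (n_quantized_points_total, 4): batch index + xyz
print(positives_mask.shape)    # (2, 2), i.e. batch_size x batch_size
```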

Assuming batch_size = 2 and that the training data is structured as follows:

```
{
  "0": {"query": "path/to/file/xxx.bin",
        "positives": [1, 116, 117, 345, 346, ...],
        "negatives": [18671, 15181, 8746, 2052, 7919, ...]},
  "1": {"query": "path/to/file/xxx.bin",
        "positives": [0, 2, 117, 118, 346, ...],
        "negatives": [10283, 20550, 17938, 4424, 8452, ...]},
  ...
}
```
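
My reading is that these index lists are binarized into one bit per query, so dataset.queries[label]['positives'][e] is a constant-time bit lookup. A sketch of that conversion, assuming the bitarray package (the helper name and the num_queries value are mine, not the repo's):

```python
from bitarray import bitarray

def indices_to_bitarray(indices, num_queries):
    # bit e is True iff query e appears in the index list
    bits = bitarray(num_queries)
    bits.setall(False)
    for i in indices:
        bits[i] = True
    return bits

# e.g. query 0's positives from the (truncated) structure above
positives_0 = indices_to_bitarray([1, 116, 117, 345, 346], num_queries=20000)
print(positives_0[1], positives_0[0])  # True False (0 is not in its own positives list)
```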

For query 0, labels is [0, 1], so its row of positives_mask is [dataset.queries[0]['positives'][0], dataset.queries[0]['positives'][1]]. Because dataset.queries has been binarized over range(len(index)) and excludes query_idx itself, with such a small batch size positives_mask should be [1, 1], and similarly negatives_mask should be [0, 0].

For queries i and i+1, labels is [i, i+1], and we look up the i-th and (i+1)-th positive and negative bits for each of them. In effect, we always read the bits at the batch_size labels starting from i.
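
To make the indexing concrete, here is the nested comprehension on a made-up two-element batch (the bit values are invented for illustration):

```python
import torch

labels = [0, 1]
# Toy bit-lists standing in for the per-query bitarrays
queries = {
    0: {'positives': [0, 1], 'negatives': [0, 0]},
    1: {'positives': [1, 0], 'negatives': [0, 0]},
}

positives_mask = torch.tensor([[queries[a]['positives'][b] for b in labels] for a in labels])
negatives_mask = torch.tensor([[queries[a]['negatives'][b] for b in labels] for a in labels])
# Row i, column j answers: "is batch element j a positive (negative) of batch element i?"
print(positives_mask)  # tensor([[0, 1], [1, 0]])
print(negatives_mask)  # tensor([[0, 0], [0, 0]])
```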

- Is there anything wrong with my understanding?
- What is the meaning of positives_mask and negatives_mask?
- Why not feed the positives and negatives into the network and compute the embedding distance between the anchor and them?

Could you please give me a brief explanation?
Thanks in advance.

You could refer to this.