An error in loss.py
SilvesterYu commented
Hi, I just wanted to point out that I had to make the following changes to get the MinkLoc3D code to run.
In `loss.py`, the `mask` variables need to be converted to `torch.uint8` tensors using `mask = (mask > 0).type(torch.uint8)`.
Please see my edits below.
```python
import torch


def get_max_per_row(mat, mask):
    # -- Do casting to get uint8 -- #
    mask = (mask > 0).type(torch.uint8)
    non_zero_rows = torch.any(mask.bool(), dim=1)
    mat_masked = mat.clone()
    mat_masked[~mask] = 0
    return torch.max(mat_masked, dim=1), non_zero_rows


def get_min_per_row(mat, mask):
    # -- Do casting to get uint8 -- #
    mask = (mask > 0).type(torch.uint8)
    non_inf_rows = torch.any(mask, dim=1)
    mat_masked = mat.clone()
    mat_masked[~mask] = float('inf')
    return torch.min(mat_masked, dim=1), non_inf_rows
```
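Here is a small check of what the cast changes, assuming the incoming mask is a `torch.bool` tensor (the 2x2 values are hypothetical). The 0/1 pattern survives the cast, but `~` is a bitwise NOT. On a `uint8` tensor it therefore no longer produces the complementary 0/1 mask that `mat_masked[~mask]` relies on.

```python
import torch

# Hypothetical 2x2 mask, as a comparison would produce it (dtype torch.bool)
bool_mask = torch.tensor([[True, False],
                          [False, False]])

# The cast from the edit above: same 0/1 pattern, now stored as torch.uint8
uint8_mask = (bool_mask > 0).type(torch.uint8)

# `~` is bitwise NOT: on bool it flips True/False,
# on uint8 it flips all 8 bits, so 1 -> 254 and 0 -> 255
inv_bool = ~bool_mask
inv_uint8 = ~uint8_mask

print(inv_bool.tolist())   # [[False, True], [True, True]]
print(inv_uint8.tolist())  # [[254, 255], [255, 255]]
```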
SilvesterYu commented
Hi, sorry about this. The error I got earlier was probably caused by something else, not by the functions in `loss.py`. I have changed them back to what you had before.
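For comparison, here is a minimal sketch of the two helpers without the cast. It assumes both masks are already `torch.bool` tensors with the same shape as the distance matrix. It mirrors the functions quoted above minus the `uint8` conversion and is not necessarily identical to the MinkLoc3D source. The distance matrix and masks in the usage part are hypothetical.

```python
import torch


def get_max_per_row(mat: torch.Tensor, mask: torch.Tensor):
    # Assumes `mask` is torch.bool with the same shape as `mat`
    non_zero_rows = torch.any(mask, dim=1)   # rows with at least one True entry
    mat_masked = mat.clone()
    mat_masked[~mask] = 0                    # entries outside the mask cannot win the max
    return torch.max(mat_masked, dim=1), non_zero_rows


def get_min_per_row(mat: torch.Tensor, mask: torch.Tensor):
    # Assumes `mask` is torch.bool with the same shape as `mat`
    non_inf_rows = torch.any(mask, dim=1)
    mat_masked = mat.clone()
    mat_masked[~mask] = float('inf')         # entries outside the mask cannot win the min
    return torch.min(mat_masked, dim=1), non_inf_rows


# Hypothetical pairwise distances for three embeddings
dist = torch.tensor([[0.0, 0.4, 0.9],
                     [0.4, 0.0, 0.2],
                     [0.9, 0.2, 0.0]])
# Hypothetical positives (row 2 has none) and negatives (excluding the anchor itself)
positives = torch.tensor([[False, True, False],
                          [True, False, False],
                          [False, False, False]])
negatives = torch.tensor([[False, False, True],
                          [False, False, True],
                          [True, True, False]])

(hardest_pos, pos_idx), has_pos = get_max_per_row(dist, positives)
(hardest_neg, neg_idx), has_neg = get_min_per_row(dist, negatives)
print(hardest_pos.tolist(), has_pos.tolist())
print(hardest_neg.tolist(), neg_idx.tolist())
```

With boolean masks, `~mask` is a plain logical complement, so no cast is needed. The `has_pos`/`has_neg` flags let the caller discard rows that had no valid entry, such as row 2 in the positives here.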