Add metadata input fields
Pitch
For SSL pretraining, we're interested in non-raster targets. For example, with the Satlas dataset, we'll have 7 target types (points, polygons, polylines, segmentation, regression, properties, classification).
At the moment, the `OdnDataset` class executes the preprocessor, which adds only input fields of type `raster` or `mask` to the output dict. Depending on the type, the preprocessor calls either `apply_to_raster` or `apply_to_mask`, which opens the referenced file with rasterio.
It would be useful if the preprocessor could also add metadata that does not need to be opened with rasterio because it is stored directly as a field (column) in the CSV.
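As an illustration, an `input_fields` configuration that mixes file-backed inputs with plain CSV columns could look like the sketch below. The output keys, the column names, and the `"metadata"` type string are all hypothetical. Under this proposal, any type other than `raster` or `mask` would simply be copied from the CSV row.

```python
# Hypothetical input_fields: each key is an output-dict key,
# "name" is the CSV column it is read from.
input_fields = {
    "image": {"name": "image_path", "type": "raster"},  # path opened with rasterio
    "label": {"name": "label_path", "type": "mask", "one_hot_encoding": False},
    # Hypothetical metadata fields stored directly in the CSV (no rasterio needed)
    "category": {"name": "category", "type": "metadata"},
    "value": {"name": "regression_value", "type": "metadata"},
}

# Split the fields into those that need a file to be opened and those copied as-is
file_backed = [k for k, v in input_fields.items() if v["type"] in ("raster", "mask")]
metadata = [k for k, v in input_fields.items() if v["type"] not in ("raster", "mask")]
```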
Solution
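Add a fallback branch to the preprocessor's `forward` method: any input field whose type is neither `raster` nor `mask` is copied unchanged from the CSV row into the output dict.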
```python
for key, value in self._input_fields.items():
    if value["type"] == "raster":
        # Resolve the raster path, optionally relative to root_dir
        path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
        band_indices = value.get("band_indices")
        dtype_max = value.get("dtype_max")
        mean = value.get("mean")
        std = value.get("std")
        # Derive dtype_max from the declared dtype when it is not given explicitly
        if dtype_max is None and "dtype" in value:
            dtype = value["dtype"]
            if dtype in DTYPE_MAX:
                dtype_max = DTYPE_MAX[dtype]
            else:
                raise KeyError(f'your dtype {dtype} for key {key} in your input_fields is not compatible')
        # Fall back to the uint8 maximum
        dtype_max = DTYPE_MAX[InputDType.UINT8.value] if dtype_max is None else dtype_max
        output_dict[key] = self.apply_to_raster(path=path,
                                                band_indices=band_indices,
                                                bounds=bounds,
                                                dtype_max=dtype_max,
                                                mean=mean,
                                                std=std)
    elif value["type"] == "mask":
        path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
        band_indices = value.get("band_indices")
        one_hot_encoding = value.get("one_hot_encoding", False)
        output_dict[key] = self.apply_to_mask(path=path,
                                              band_indices=band_indices,
                                              bounds=bounds,
                                              one_hot_encoding=one_hot_encoding)
    else:
        # New branch: metadata is copied from the CSV row as-is
        output_dict[key] = data[value["name"]]
```
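To check the dispatch logic in isolation, here is a minimal, self-contained sketch of the same idea. `build_output` is a hypothetical helper, and the loaders are passed in as callables so the sketch does not depend on rasterio. In the real preprocessor they would be `apply_to_raster` and `apply_to_mask`, with their extra arguments (`bounds`, `band_indices`, and so on).

```python
from pathlib import Path
from typing import Any, Callable, Mapping, Optional

FILE_TYPES = ("raster", "mask")


def build_output(
    data: Mapping[str, Any],
    input_fields: Mapping[str, Mapping[str, Any]],
    load_raster: Callable[[Any], Any],
    load_mask: Callable[[Any], Any],
    root_dir: Optional[str] = None,
) -> dict:
    """Dispatch each input field on its type; copy every other type as metadata."""
    output = {}
    for key, field in input_fields.items():
        raw = data[field["name"]]
        if field["type"] in FILE_TYPES:
            # File-backed field: resolve the path the same way the preprocessor does
            path = raw if root_dir is None else Path(root_dir) / raw
            loader = load_raster if field["type"] == "raster" else load_mask
            output[key] = loader(path)
        else:
            # Metadata field: the CSV value goes straight into the output
            output[key] = raw
    return output
```

Because the `else` branch catches every unknown type, new target kinds (points, polylines, and so on) need no code change as long as the raw CSV value is usable as-is.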
Should be resolved; I added an assert for this in the unit test `test_dataloader_factory_by_patch`.
Update: the output dict now receives the missing keys. Each consumed key is deleted from the input data, and whatever remains is merged in:
```python
for key, value in self._input_fields.items():
    if value["type"] == "raster":
        path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
        band_indices = value.get("band_indices")
        dtype_max = value.get("dtype_max")
        mean = value.get("mean")
        std = value.get("std")
        if dtype_max is None and "dtype" in value:
            dtype = value["dtype"]
            if dtype in DTYPE_MAX:
                dtype_max = DTYPE_MAX[dtype]
            else:
                raise KeyError(f'your dtype {dtype} for key {key} in your input_fields is not compatible')
        dtype_max = DTYPE_MAX[InputDType.UINT8.value] if dtype_max is None else dtype_max
        output_dict[key] = self.apply_to_raster(path=path,
                                                band_indices=band_indices,
                                                bounds=bounds,
                                                dtype_max=dtype_max,
                                                mean=mean,
                                                std=std)
        # The raster path has been consumed: drop it from the raw row
        del data[value["name"]]
    elif value["type"] == "mask":
        path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
        band_indices = value.get("band_indices")
        one_hot_encoding = value.get("one_hot_encoding", False)
        output_dict[key] = self.apply_to_mask(path=path,
                                              band_indices=band_indices,
                                              bounds=bounds,
                                              one_hot_encoding=one_hot_encoding)
        del data[value["name"]]

# Replace the shapely geometry with its (minx, miny, maxx, maxy) bounds
if "geometry" in data.keys():
    output_dict["geometry"] = np.array(data["geometry"].bounds)
    del data["geometry"]

# Every column still in data (the metadata) is passed through;
# on a key clash, the processed value in output_dict wins
output_dict = {**data, **output_dict}
```
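The merge semantics of `{**data, **output_dict}` can be checked in isolation with the sketch below. `merge_remaining` is a hypothetical helper, not part of the codebase. One design difference is deliberate: it works on a copy of the row, so the caller's data is not mutated by the deletions.

```python
from typing import Any, Iterable, Mapping


def merge_remaining(
    data: Mapping[str, Any],
    consumed: Iterable[str],
    processed: Mapping[str, Any],
) -> dict:
    """Drop consumed CSV columns, then merge the rest with the processed fields."""
    remaining = dict(data)  # copy so the caller's row is left intact
    for column in consumed:
        remaining.pop(column, None)
    # In {**a, **b}, keys from b override keys from a
    return {**remaining, **processed}
```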
Warning
In the near future, a custom collate function will be needed to handle types that PyTorch's default collate function does not support, such as shapely geometry objects.
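One possible shape for such a collate function is sketched below, assuming each sample is a flat dict. Values of the listed "passthrough" types (or `None`, such as missing metadata) are kept as plain Python lists. Everything else is handed to a fallback collate function. `make_collate` is a hypothetical name, and the fallback is injected so the sketch itself does not import PyTorch.

```python
from typing import Any, Callable, Dict, List, Sequence


def make_collate(
    fallback: Callable[[List[Any]], Any],
    passthrough_types: tuple,
) -> Callable[[Sequence[Dict[str, Any]]], Dict[str, Any]]:
    """Build a collate_fn for batches of dict samples.

    Per key, values that are all None or instances of passthrough_types are
    returned as a list; otherwise the values are collated by `fallback`.
    """

    def collate(batch: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for key in batch[0]:
            values = [sample[key] for sample in batch]
            if all(v is None or isinstance(v, passthrough_types) for v in values):
                out[key] = values  # e.g. shapely geometries, kept as-is
            else:
                out[key] = fallback(values)
        return out

    return collate
```

With PyTorch and shapely installed, this could be wired up as `DataLoader(dataset, collate_fn=make_collate(default_collate, (BaseGeometry,)))`, with `default_collate` imported from `torch.utils.data` (available in recent PyTorch versions) and `BaseGeometry` from `shapely.geometry.base`.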