IGNF/odeon

Add meta data input fields


Pitch

For SSL pretraining, we're interested in having non-raster targets. For example, the Satlas dataset will have seven target types (points, polygons, polylines, segmentation, regression, properties, and classification).

At the moment, the OdnDataset class runs the preprocessor, which only adds input fields of type raster or mask to the output dict. Depending on the type, the preprocessor calls either apply_to_raster or apply_to_mask to open the referenced file with rasterio.

It would be useful if the preprocessor could also add metadata that does not need to be opened with rasterio and is stored directly as a field in the CSV.

Solution

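Add a fallback branch in the preprocessor's forward method, so that any input field whose type is neither raster nor mask is copied directly from the CSV row: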

for key, value in self._input_fields.items():
    if value["type"] == "raster":
        path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
        band_indices = value.get("band_indices")
        dtype_max = value.get("dtype_max")
        mean = value.get("mean")
        std = value.get("std")
        if dtype_max is None and "dtype" in value:
            dtype = value["dtype"]
            if dtype in DTYPE_MAX:
                dtype_max = DTYPE_MAX[dtype]
            else:
                raise KeyError(f"your dtype {dtype} for key {key} in your input_fields is not compatible")
        # Fall back to the uint8 maximum when neither dtype_max nor dtype is given.
        dtype_max = DTYPE_MAX[InputDType.UINT8.value] if dtype_max is None else dtype_max
        output_dict[key] = self.apply_to_raster(path=path,
                                                band_indices=band_indices,
                                                bounds=bounds,
                                                dtype_max=dtype_max,
                                                mean=mean,
                                                std=std)
    elif value["type"] == "mask":
        path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
        band_indices = value.get("band_indices")
        one_hot_encoding = value.get("one_hot_encoding", False)
        output_dict[key] = self.apply_to_mask(path=path,
                                              band_indices=band_indices,
                                              bounds=bounds,
                                              one_hot_encoding=one_hot_encoding)
    else:
        # New branch: any other type is metadata, taken as-is from the CSV row.
        output_dict[key] = data[value["name"]]
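To make the proposed dispatch easier to test in isolation, here is a minimal standalone sketch of the same idea. The function name preprocess_row, the stub readers, and the "classification" type name are hypothetical and used only for illustration; the real apply_to_raster and apply_to_mask open files with rasterio and take more arguments (bounds, dtype_max, mean, std, one_hot_encoding) than the band_indices this sketch forwards.

```python
from pathlib import Path
from typing import Any, Callable, Dict, Optional


def preprocess_row(
    data: Dict[str, Any],
    input_fields: Dict[str, Dict[str, Any]],
    apply_to_raster: Callable[..., Any],
    apply_to_mask: Callable[..., Any],
    root_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the output dict for one CSV row by dispatching on each field's type.

    "raster" and "mask" values are treated as file paths and handed to the
    matching reader; every other type is treated as metadata and copied as-is.
    """
    output: Dict[str, Any] = {}
    for key, field in input_fields.items():
        raw = data[field["name"]]
        if field["type"] in ("raster", "mask"):
            path = raw if root_dir is None else Path(root_dir) / raw
            reader = apply_to_raster if field["type"] == "raster" else apply_to_mask
            output[key] = reader(path=path, band_indices=field.get("band_indices"))
        else:
            # Metadata field: nothing to open, keep the CSV value unchanged.
            output[key] = raw
    return output


# Usage with hypothetical stub readers that only report what they were given.
row = {"img": "a.tif", "msk": "a_mask.tif", "label": "forest"}
fields = {
    "image": {"name": "img", "type": "raster", "band_indices": [1, 2, 3]},
    "mask": {"name": "msk", "type": "mask"},
    "target": {"name": "label", "type": "classification"},
}
out = preprocess_row(
    row,
    fields,
    apply_to_raster=lambda path, band_indices: ("raster", path.name, band_indices),
    apply_to_mask=lambda path, band_indices: ("mask", path.name, band_indices),
    root_dir="/data",
)
print(out["target"])  # forest
```

Because the metadata branch is a plain fallback, any new target type from Satlas passes through without touching the raster and mask code paths.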

PR fix: #26

This should be resolved: I added an assertion for this issue in the unit test test_dataloader_factory_by_patch.
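The output dict is now completed with the missing keys: each column consumed by a raster or mask field is deleted from the input data, and the remaining fields are merged into the output: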

        for key, value in self._input_fields.items():
            if value["type"] == "raster":
                path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
                band_indices = value.get("band_indices")
                dtype_max = value.get("dtype_max")
                mean = value.get("mean")
                std = value.get("std")
                if dtype_max is None and "dtype" in value:
                    dtype = value["dtype"]
                    if dtype in DTYPE_MAX:
                        dtype_max = DTYPE_MAX[dtype]
                    else:
                        raise KeyError(f"your dtype {dtype} for key {key} in your input_fields is not compatible")
                dtype_max = DTYPE_MAX[InputDType.UINT8.value] if dtype_max is None else dtype_max
                output_dict[key] = self.apply_to_raster(path=path,
                                                        band_indices=band_indices,
                                                        bounds=bounds,
                                                        dtype_max=dtype_max,
                                                        mean=mean,
                                                        std=std)
                # The path column has been consumed: drop it from the leftover data.
                del data[value["name"]]
            elif value["type"] == "mask":
                path = data[value["name"]] if self.root_dir is None else Path(str(self.root_dir)) / data[value["name"]]
                band_indices = value.get("band_indices")
                one_hot_encoding = value.get("one_hot_encoding", False)
                output_dict[key] = self.apply_to_mask(path=path,
                                                      band_indices=band_indices,
                                                      bounds=bounds,
                                                      one_hot_encoding=one_hot_encoding)
                del data[value["name"]]
        # Replace a shapely geometry by its (minx, miny, maxx, maxy) bounds so it can be batched.
        if "geometry" in data:
            output_dict["geometry"] = np.array(data["geometry"].bounds)
            del data["geometry"]
        # Merge the remaining CSV fields (metadata); processed values take precedence on key clashes.
        output_dict = {**data, **output_dict}
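The merge step can also be written without mutating the incoming row, which avoids surprises if the caller still holds a reference to it. The following is a minimal sketch under that assumption: merge_outputs and BoxGeometry are hypothetical names, BoxGeometry only mimics the bounds attribute of a shapely geometry, and the bounds are kept as a plain list instead of np.array so the example has no dependencies.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class BoxGeometry:
    """Hypothetical stand-in for a shapely geometry; only `.bounds` is used."""
    bounds: Tuple[float, float, float, float]


def merge_outputs(
    data: Dict[str, Any],
    input_fields: Dict[str, Dict[str, Any]],
    opened: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge leftover CSV columns into the already-opened raster/mask outputs.

    Consumed columns are filtered into a new dict, so `data` is left untouched.
    """
    consumed = {f["name"] for f in input_fields.values() if f["type"] in ("raster", "mask")}
    remaining = {col: val for col, val in data.items() if col not in consumed}
    output = dict(opened)
    if "geometry" in remaining:
        # Keep only the (minx, miny, maxx, maxy) bounds of the geometry object.
        output["geometry"] = list(remaining.pop("geometry").bounds)
    # Same precedence as {**data, **output_dict}: processed values win on key clashes.
    return {**remaining, **output}


row = {"img": "a.tif", "city": "Paris", "geometry": BoxGeometry((0.0, 1.0, 2.0, 3.0))}
fields = {"image": {"name": "img", "type": "raster"}}
merged = merge_outputs(row, fields, opened={"image": "opened-array"})
print(merged)  # {'city': 'Paris', 'image': 'opened-array', 'geometry': [0.0, 1.0, 2.0, 3.0]}
```

Filtering instead of deleting also means that two input fields pointing at the same CSV column no longer risk a KeyError on the second del.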

Warning

In the near future, a custom collate function will be needed to handle values that PyTorch's default collate function cannot batch, such as shapely geometry objects.
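One possible shape for such a collate function is sketched below, assuming the goal is to batch everything the default function understands and keep the remaining values as per-sample lists. make_collate_fn and its passthrough_keys parameter are hypothetical names. The default collate function is injected as an argument so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Dict, List, Sequence

Batch = List[Dict[str, Any]]


def make_collate_fn(
    passthrough_keys: Sequence[str],
    default_collate: Callable[[Batch], Dict[str, Any]],
) -> Callable[[Batch], Dict[str, Any]]:
    """Wrap `default_collate` so that values under `passthrough_keys`
    (e.g. shapely geometries) are returned as plain per-sample lists
    instead of being handed to the default collate function."""
    passthrough = set(passthrough_keys)

    def collate(batch: Batch) -> Dict[str, Any]:
        # Collate everything the default function understands.
        collatable = [{k: v for k, v in s.items() if k not in passthrough} for s in batch]
        out = dict(default_collate(collatable))
        # Keep the remaining values unbatched, one entry per sample.
        for key in passthrough:
            if key in batch[0]:
                out[key] = [s[key] for s in batch]
        return out

    return collate
```

With PyTorch 1.11 or later, where default_collate is exported from torch.utils.data, this could be plugged in as DataLoader(dataset, batch_size=8, collate_fn=make_collate_fn(["geometry"], default_collate)), with dataset being the OdnDataset instance. Keeping unbatchable values as lists pushes the decision about how to use them (for example, rasterizing polygons) to the model side.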