open-mmlab/mmengine

[Feature] Support calculating loss in the validation step


What is the feature?

We often receive requests from users who want to be able to print the loss during the validation phase, but MMEngine does not support this feature at the moment. If you also have a need for this, feel free to discuss it.

There are two possible solutions for MMEngine to support this feature.

One is to add a LossMetric to the downstream library. The forward method of the downstream library would still return a list of DataElement, and the LossMetric would calculate the loss from the information stored in each DataElement.
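To make the first option concrete, here is a minimal sketch, assuming a classification task where each data sample (already converted to a dict) carries hypothetical `pred_score` (class probabilities) and `gt_label` (class index) fields. The class only mirrors the `process()` / `compute_metrics()` split of mmengine's `BaseMetric`; it does not import mmengine, and `LossMetric` here is a hypothetical name, not an existing API.

```python
import math
from typing import Any, Dict, List, Sequence


class LossMetric:
    """Hypothetical metric that recomputes a loss from stored predictions.

    Mirrors the process()/compute_metrics() split of mmengine's BaseMetric
    without depending on mmengine. Each data sample is assumed to be a dict
    holding 'pred_score' (class probabilities) and 'gt_label' (an int).
    """

    def __init__(self, eps: float = 1e-12) -> None:
        self.eps = eps  # guards against log(0)
        self.results: List[float] = []

    def process(self, data_batch: Any,
                data_samples: Sequence[Dict[str, Any]]) -> None:
        # Store the per-sample cross-entropy: -log p(gt_label).
        for sample in data_samples:
            prob = sample['pred_score'][sample['gt_label']]
            self.results.append(-math.log(max(prob, self.eps)))

    def compute_metrics(self, results: List[float]) -> Dict[str, float]:
        # Average the per-sample losses collected over the whole loop.
        return {'loss': sum(results) / len(results)}


metric = LossMetric()
metric.process(None, [{'pred_score': [0.5, 0.5], 'gt_label': 0},
                      {'pred_score': [0.25, 0.75], 'gt_label': 1}])
print(metric.compute_metrics(metric.results))
```

The trade-off of this design is that the model stays untouched, but the metric can only reproduce losses that are computable from what the predictions expose; losses that depend on intermediate features would not fit.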

The other is for the forward method of the downstream library to return a dictionary. In this case, the downstream library has to complete the loss calculation inside forward and return the result to MMEngine.

https://github.com/open-mmlab/mmpretrain/blob/17a886cb5825cd8c26df4e65f7112d404b99fe12/mmpretrain/models/classifiers/image.py#L87
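A minimal sketch of the second option, assuming a toy predict-mode forward that computes the loss itself and returns a dict instead of a plain list. The names `toy_forward`, `collect_outputs`, and the `'predictions'` / `'loss'` keys are hypothetical, and the 0-1 loss is only a stand-in for a real loss function.

```python
from typing import Any, Dict, List, Sequence, Tuple


def toy_forward(scores: Sequence[Sequence[float]],
                labels: Sequence[int]) -> Dict[str, Any]:
    """Hypothetical predict-mode forward that also computes the loss."""
    # Prediction: index of the highest score for each sample.
    preds = [max(range(len(s)), key=s.__getitem__) for s in scores]
    # Stand-in loss: fraction of misclassified samples in the batch.
    loss = sum(p != y for p, y in zip(preds, labels)) / len(labels)
    return {'predictions': preds, 'loss': loss}


def collect_outputs(
        batches: Sequence[Tuple[Sequence[Sequence[float]], Sequence[int]]]
) -> Dict[str, Any]:
    """Hypothetical loop body: split predictions from the reported loss."""
    all_preds: List[int] = []
    losses: List[float] = []
    for scores, labels in batches:
        out = toy_forward(scores, labels)
        all_preds.extend(out['predictions'])  # would go to the evaluator
        losses.append(out['loss'])            # would go to the logger
    # Mean of the per-batch losses.
    return {'predictions': all_preds, 'val_loss': sum(losses) / len(losses)}


result = collect_outputs([([[0.9, 0.1], [0.2, 0.8]], [0, 0]),
                          ([[0.3, 0.7]], [1])])
print(result)
```

Here the model owns the loss computation, so any loss it can compute during training is available, but every downstream forward has to change its return type in this mode.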

References

Any other context?

No response

Hi, I checked the current work in #1503.
Apart from the ValLoop changes, I have some suggestions about how to return the losses.

MMEngine

To be more general and consistent with the current implementation style,
I think we should use a loss_and_predict function, similar to the loss_and_predict of the head modules, instead of computing the loss directly in predict mode.

In the val_step and test_step functions, we could add an optional argument, e.g., loss: bool = False, which introduces no breaking changes.
(Note: calculating the loss outside the training step should always be optional.)

    # Proposed changes to mmengine.model.BaseModel (assumes Dict, Tuple,
    # Union from typing and torch are imported in the module).
    def val_step(self,
                 data: Union[tuple, dict, list],
                 loss: bool = False
                 ) -> Union[list, Tuple[list, Dict[str, torch.Tensor]]]:
        """Gets the predictions (and optionally the losses) of given data.

        Calls ``self.data_preprocessor(data, False)`` and then
        ``self(inputs, data_sample, mode='predict')``, or
        ``mode='loss_and_predict'`` when ``loss=True``. The predictions
        are passed to the evaluator.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            loss (bool): Whether to also calculate the losses.
                Defaults to False.

        Returns:
            list or tuple[list, dict]: The predictions of given data, or a
            tuple of the predictions and a dict of loss components when
            ``loss=True``.
        """
        data = self.data_preprocessor(data, False)
        # The default path (loss=False) behaves exactly as before.
        mode = 'loss_and_predict' if loss else 'predict'
        return self._run_forward(data, mode=mode)  # type: ignore

    def test_step(self,
                  data: Union[dict, tuple, list],
                  loss: bool = False
                  ) -> Union[list, Tuple[list, Dict[str, torch.Tensor]]]:
        """``BaseModel`` implements ``test_step`` the same as ``val_step``.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            loss (bool): Whether to also calculate the losses.
                Defaults to False.

        Returns:
            list or tuple[list, dict]: The predictions of given data, or a
            tuple of the predictions and a dict of loss components when
            ``loss=True``.
        """
        data = self.data_preprocessor(data, False)
        mode = 'loss_and_predict' if loss else 'predict'
        return self._run_forward(data, mode=mode)  # type: ignore
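On the ValLoop side, the optional flag could be consumed roughly as follows. This is a minimal sketch with a hypothetical `FakeModel` standing in for a `BaseModel` subclass and a hypothetical `run_val_loop` function standing in for the loop body; it is not mmengine's actual `ValLoop`, and the mean-absolute-error loss is only illustrative.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union


class FakeModel:
    """Stand-in for a BaseModel subclass that accepts the proposed flag."""

    def val_step(self, data: Dict[str, Any], loss: bool = False
                 ) -> Union[list, Tuple[list, Dict[str, float]]]:
        preds = [2 * x for x in data['inputs']]
        if not loss:
            return preds
        # Hypothetical loss: mean absolute error against the targets.
        err = sum(abs(p - t) for p, t in zip(preds, data['targets']))
        return preds, {'loss': err / len(preds)}


def run_val_loop(model: FakeModel, dataloader: Sequence[Dict[str, Any]],
                 with_loss: bool = False) -> Tuple[list, Dict[str, float]]:
    """Hypothetical ValLoop body: keep predictions, average each loss key."""
    predictions: List[Any] = []
    loss_sums: Dict[str, float] = {}
    for data in dataloader:
        if with_loss:
            preds, losses = model.val_step(data, loss=True)
            for key, value in losses.items():
                loss_sums[key] = loss_sums.get(key, 0.0) + value
        else:
            # Default path is unchanged: only predictions are returned.
            preds = model.val_step(data)
        predictions.extend(preds)  # a real loop passes these to the evaluator
    return predictions, {k: v / len(dataloader) for k, v in loss_sums.items()}


loader = [{'inputs': [1, 2], 'targets': [2, 5]},
          {'inputs': [3], 'targets': [6]}]
preds, val_losses = run_val_loop(FakeModel(), loader, with_loss=True)
print(preds, val_losses)
```

Because `with_loss` defaults to `False`, existing models that never implement the new mode keep working unchanged.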

Downstream libraries

As for the implementation of the loss_and_predict function: since the current model implementations cannot compute predictions and losses in a single call, we could implement it by calling loss and predict separately and returning both results.

Although this means running inference twice, I think there is no other choice if we want to keep backward compatibility.
Each model would have to add the loss_and_predict mode to its own forward function (an optional feature).

    # Assumes mmdet-style imports: Tuple (typing), Tensor (torch) and
    # SampleList (mmdet.structures).
    def loss_and_predict(self,
                         batch_inputs: Tensor,
                         batch_data_samples: SampleList,
                         rescale: bool = True
                         ) -> Tuple[SampleList, dict]:
        """Predict results and calculate losses from a batch of inputs and
        data samples with post-processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instances` or `gt_panoptic_seg` or `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            tuple: A tuple ``(preds, losses)``, where

            - preds (list[:obj:`DetDataSample`]): Detection results of the
              input images. Each DetDataSample usually contains
              `pred_instances`, which usually contains the following keys:

              - scores (Tensor): Classification scores, has a shape
                (num_instances, ).
              - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
              - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
            - losses (dict): A dictionary of loss components.
        """
        # Reuse the existing methods so their behavior stays unchanged,
        # at the cost of running the network twice.
        preds = self.predict(batch_inputs, batch_data_samples, rescale=rescale)
        losses = self.loss(batch_inputs, batch_data_samples)
        return preds, losses

I think this is a more elegant and ideal solution.
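To show how a model would opt into the new mode, here is a minimal sketch, assuming a toy model whose `forward()` dispatches on `mode` in the same spirit as mmengine's `'tensor'` / `'predict'` / `'loss'` convention. `ToyModel`, `extract_feat`'s body, and the `num_feature_passes` counter are hypothetical; the counter only makes the cost of the two passes visible.

```python
from typing import Dict, List, Optional, Tuple


class ToyModel:
    """Hypothetical model whose forward() gains a 'loss_and_predict' mode."""

    def __init__(self) -> None:
        self.num_feature_passes = 0  # how many times the "network" ran

    def extract_feat(self, inputs: List[float]) -> List[float]:
        self.num_feature_passes += 1
        return [2 * x for x in inputs]

    def loss(self, inputs: List[float],
             targets: List[float]) -> Dict[str, float]:
        feats = self.extract_feat(inputs)
        sq_err = [(f - t) ** 2 for f, t in zip(feats, targets)]
        return {'loss': sum(sq_err) / len(sq_err)}

    def predict(self, inputs: List[float],
                targets: Optional[List[float]] = None) -> List[float]:
        return self.extract_feat(inputs)

    def loss_and_predict(self, inputs: List[float], targets: List[float]
                         ) -> Tuple[List[float], Dict[str, float]]:
        # Backward-compatible composition of the two existing methods;
        # the price is a second pass through extract_feat().
        return self.predict(inputs, targets), self.loss(inputs, targets)

    def forward(self, inputs: List[float],
                targets: Optional[List[float]] = None,
                mode: str = 'tensor'):
        if mode == 'tensor':
            return self.extract_feat(inputs)
        if mode == 'predict':
            return self.predict(inputs, targets)
        if mode == 'loss':
            return self.loss(inputs, targets)
        if mode == 'loss_and_predict':  # the new, opt-in mode
            return self.loss_and_predict(inputs, targets)
        raise RuntimeError(f'Invalid mode "{mode}".')


model = ToyModel()
preds, losses = model.forward([1.0, 2.0], [2.0, 5.0], mode='loss_and_predict')
print(preds, losses, model.num_feature_passes)
```

After one `loss_and_predict` call, `num_feature_passes` is 2, which is exactly the double forward pass this approach accepts in exchange for not touching `loss()` and `predict()`.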

Hi @guyleaf, we have discussed your proposal before. Your solution is elegant, but it requires running the model's forward pass twice.