georgian-io/Multimodal-Toolkit

Support for imbalanced classification

curiousRed opened this issue · 2 comments

Thank you for the library; it made the code for the project I am currently working on much easier. Is it possible to set class weights for imbalanced classification? Thank you!
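Setting per-class weights usually means weighting the loss rather than resampling the data. Below is a minimal sketch in plain PyTorch, not a Multimodal-Toolkit API. It assumes integer labels 0..K-1 with every class present, computes "balanced" inverse-frequency weights with a hypothetical helper `inverse_frequency_weights`, and passes them to `nn.CrossEntropyLoss(weight=...)`. Using this with the Hugging Face `Trainer` would mean overriding its `compute_loss`, which depends on the model's output format and is not shown here.

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: weight class c by N / (num_classes * count_c).

    Assumes every class appears at least once (otherwise the weight is inf).
    """
    counts = torch.bincount(labels, minlength=num_classes).float()
    return labels.numel() / (num_classes * counts)


# Toy imbalanced binary labels: 8 negatives, 2 positives.
labels = torch.tensor([0] * 8 + [1] * 2)
weights = inverse_frequency_weights(labels, num_classes=2)  # rare class gets the larger weight

# Each example's loss term is scaled by the weight of its true class.
loss_fn = nn.CrossEntropyLoss(weight=weights)
logits = torch.zeros(len(labels), 2)  # an uninformative model: equal logits for both classes
loss = loss_fn(logits, labels)
print(weights, loss.item())
```

With the default `reduction="mean"`, PyTorch divides by the sum of the weights of the targets, so the equal-logit loss stays at log 2; the weights only change the loss once the model favors one class.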

For those wondering: you can subclass the base `Trainer` class and override its training dataloader so that the minority class is oversampled (here I'm doing binary classification), like so:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from transformers import Trainer


class CW_Trainer(Trainer):
    def get_train_dataloader(self):
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Unlike the default implementation, examples are drawn with a
        WeightedRandomSampler whose per-example weight is
        1 / (number of training examples in that example's class),
        so every class is sampled about equally often.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # Integer class labels of the training set. As in the original post,
        # this assumes the dataset exposes a `labels` attribute and that the
        # labels are 0..K-1.
        target = np.asarray(self.train_dataset.labels, dtype=np.int64)
        class_sample_count = np.bincount(target)
        print("train examples per class:", class_sample_count.tolist())

        # Per-example weight = 1 / size of its class (float64 tensor).
        samples_weight = torch.from_numpy(1.0 / class_sample_count[target]).double()
        train_sampler = WeightedRandomSampler(
            samples_weight, num_samples=len(samples_weight), replacement=True
        )

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )


trainer = CW_Trainer(
    model,
    # The remaining arguments were cut off in the original post
    # (typically args=..., train_dataset=..., as for a regular Trainer).
)
```
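The weighting inside `get_train_dataloader` can be checked on its own. The sketch below uses made-up labels (90 negatives, 10 positives) and a hypothetical helper `balanced_sample_weights`. It gives every example the weight 1 / (size of its class), so each class's weights sum to 1 and each draw is about equally likely to come from either class. Because sampling is with replacement, minority examples repeat within an epoch while some majority examples are skipped.

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler


def balanced_sample_weights(labels) -> torch.Tensor:
    """Hypothetical helper: per-example weight = 1 / count of that example's class."""
    labels = np.asarray(labels, dtype=np.int64)
    class_counts = np.bincount(labels)
    # Index first, then divide, so absent classes (count 0) never cause a division by zero.
    return torch.from_numpy(1.0 / class_counts[labels])  # float64 tensor


labels = [0] * 90 + [1] * 10
weights = balanced_sample_weights(labels)

# A seeded generator makes the draw reproducible.
generator = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(
    weights, num_samples=len(weights), replacement=True, generator=generator
)
drawn = np.asarray(labels)[list(sampler)]
print("drawn per class:", np.bincount(drawn, minlength=2).tolist())
```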

Closing as the solution is given above.