Here is the official PyTorch implementation of Certifiable Robust Multi-modal Training (CRMT), proposed in "Quantifying and Enhancing Multi-modal Robustness with Modality Preference". CRMT is a flexible training procedure that improves the robustness of multi-modal learning. Please refer to our ICLR 2024 paper for more details.
Paper Title: "Quantifying and Enhancing Multi-modal Robustness with Modality Preference"
Authors: Zequn Yang, Yake Wei, Ce Liang and Di Hu
Accepted by: The Twelfth International Conference on Learning Representations (ICLR 2024)
[arXiv]
We observe that when encountering attacks on different modalities, a multi-modal model can be particularly vulnerable to attacks on a certain modality. The following figure shows how different multi-modal models perform under each uni-modal attack: attacks on the audio modality (#a, dash-dot lines) lead to a larger performance discrepancy than attacks on the visual modality (#v, solid lines). This indicates that the preferred modality, audio here, is not necessarily the robust one, and therefore deserves particular attention in multi-modal learning.
Here, we provide theoretical support for this phenomenon. First, we derive a certified bound for multi-modal models. We then show that the uni-modal representation margins and the integration of modalities are two essential components of the certified robustness of a multi-modal model, and that both components are easily limited by modality preference. Further, we analyze the certified robustness with respect to each modality, which indicates that a multi-modal model can prefer the vulnerable modality, making attacks on this preferred modality more effective and explaining the observation above. The detailed analysis can be found in our paper.
Supported by the analysis above, we seek to stably enhance the certified robustness. However, directly regulating these two components is non-trivial, as both are intricately linked with the last linear classifier. We therefore re-formulate the classifier into uni-modal components $\tilde{W}^{(m)}$ and integration weights $a^{(m)}$ (the notation used in the pipeline below), so that the margins and the modality integration can be regulated separately.
With this newly proposed framework, the pipeline of our certified robust multi-modal training algorithm consists of two sub-processes:
- Optimize with the cross-entropy loss and a margin regularization term weighted by $\rho$ (see the formulation after this list).
- Fix $\tilde{W}^{(m)}, \phi^{(m)}$ and update $a^{(m)}$ to approach higher certified robustness: $\min_{a^{(m)}} L_2 = -\frac{1}{N} \sum_{i=1}^N r(x_i)$, where $r(x)$ is the certified lower bound.
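For reference, the margin regularization in the first step, as implemented in the demo snippet later in this README, can be written as follows, where $s^{(m)}_j(x)$ denotes the $j$-th logit of modality $m$ (a small $\epsilon$ is added inside the logarithm in code for numerical stability):

$$L_1 = \frac{1}{N} \sum_{i=1}^N \Big[ \ell_{\mathrm{CE}}\big(\mathrm{out}(x_i), y_i\big) + \rho \, \log \sum_{m} \sum_{j \neq y_i} \exp\big(s^{(m)}_j(x_i) - s^{(m)}_{y_i}(x_i)\big) \Big]$$

Minimizing the log-sum-exp term enlarges the uni-modal margins $s^{(m)}_{y_i} - s^{(m)}_{j}$ for every modality.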
- Python 3.8
pip install -r requirements.txt
You can simply run the demo of CRMT_JT using:
python main.py methods=CRMT_JT methods.gamma=1.0
You can adjust the algorithm's detailed settings by modifying parameters such as `methods.gamma` in the command above.
The original datasets we used can be found at: Kinetics-Sounds, UCF101, and VGGSound.
Our proposed Certifiable Robust Multi-modal Training (CRMT) can be applied to three training strategies, denoted as CRMT with Joint Training (CRMT-JT), CRMT with Adversarial Training (CRMT-AT), and CRMT with Mixup (CRMT-Mix).
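If the other two strategies follow the same Hydra-style configuration naming as the demo command above (an assumption based on the `CRMT_AT` and `CRMT_Mix` identifiers used later in this README), they would be launched analogously:

python main.py methods=CRMT_AT

python main.py methods=CRMT_Mix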
As a flexible training procedure, our proposed CRMT can be easily integrated into existing multi-modal learning frameworks. Since the detailed implementation of multi-modal learning models varies, we provide a simple demo illustrating how to integrate CRMT into multi-modal joint training, i.e., CRMT_JT. We display training step 1 as follows:
---in training step 1---
# out_v, out_a are the uni-modal outputs, sharing the same shape (batch_size x label_num) as the fused output `out`.
out_v, out_a, out = model(visual, audio)
outs = [out_v, out_a]
rows = torch.arange(batch_size)

# Accumulate, over modalities, sum_{j != y} exp(s_j - s_y): a smooth surrogate
# that shrinks as the uni-modal margins grow.
exp_margin_loss = 0.0
for modality in range(num_modal):
    out_cur = outs[modality]
    exp_margin_loss = exp_margin_loss + (torch.sum(torch.exp(out_cur), dim=1) * torch.exp(-out_cur[rows, labels]) - 1)

# Margin regularization weighted by rho (1e-5 keeps the log numerically stable).
loss_margin = torch.mean(torch.log(exp_margin_loss + 1e-5)) * rho
loss = criterion(out, labels) + loss_margin

optimizer.zero_grad()
loss.backward()
optimizer.step()
---continue for next training step---
Note that for CRMT_AT and CRMT_Mix, training step 1 differs from the above. The main difference lies in the input data, which consists of adversarial samples and mixup samples, respectively. Furthermore, since mixup may mix samples with two different labels, loss_margin should be modified to enlarge the margin between these two labels and all other labels. Hence, we use an indicator vector to distinguish whether the two labels are the same, as shown in the following code:
data, targets_a, targets_b, lam = mixup_data([visual, audio], labels, alpha=1.0)
out_v, out_a, out = model(*data)  # forward pass on the mixed inputs
outs = [out_v, out_a]
rows = torch.arange(batch_size)
# 1 where the two mixed labels differ, 0 where they coincide.
diff_target = (targets_a != targets_b).float()
for modality in range(num_modal):
    # Sum of exp-logits over every class except the (one or two) target labels.
    numerator = torch.sum(torch.exp(outs[modality]), dim=1) \
        - torch.exp(outs[modality][rows, targets_a]) \
        - torch.exp(outs[modality][rows, targets_b]) * diff_target
The detailed implementation can be found in our code. Training step 2 is the same for all three training strategies: it optimizes the integration weights to approach higher certified robustness. It contains two sub-steps (a sketch follows the list):
- Get the index of the largest output excluding the ground-truth label, which is realized by a mask operation.
- Calculate the certified lower bound and update the integration weight to approach higher certified robustness.
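A minimal sketch of these two sub-steps is given below, reusing the variables from the snippets above (`model`, `visual`, `audio`, `labels`). The helper `certified_lower_bound` and the optimizer `optimizer_a` over the integration weights $a^{(m)}$ are illustrative assumptions, not the released API; the exact form of the bound $r(x)$ follows the paper and the released code.

---training step 2 (sketch; illustrative names)---
out_v, out_a, out = model(visual, audio)

# Sub-step 1: index of the largest logit excluding the ground-truth label,
# obtained by masking the label column before the argmax.
rows = torch.arange(labels.size(0))
masked = out.clone()
masked[rows, labels] = float('-inf')
runner_up = masked.argmax(dim=1)

# Sub-step 2: maximize the certified lower bound w.r.t. a^(m) only,
# i.e. minimize L2 = -mean(r(x_i)); encoders and classifiers stay frozen.
r = certified_lower_bound(out_v, out_a, labels, runner_up)  # assumed helper returning r(x) per sample
loss_step2 = -r.mean()
optimizer_a.zero_grad()  # optimizer over the integration weights a^(m) only (assumption)
loss_step2.backward()
optimizer_a.step()
---end of sketch---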
If you find this work useful, please consider citing it.
@inproceedings{yang2024Quantifying,
title={Quantifying and Enhancing Multi-modal Robustness with Modality Preference},
author={Yang, Zequn and Wei, Yake and Liang, Ce and Hu, Di},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024}
}
This research was supported by the National Natural Science Foundation of China (No. 62106272), the Young Elite Scientists Sponsorship Program by CAST (2021QNRC001), and the Public Computing Cloud, Renmin University of China.
If you have any detailed questions or suggestions, you can email us: zqyang@ruc.edu.cn