Zhenhong Zhou1, Haiyang Yu1, Xinghua Zhang1, Rongwu Xu3, Fei Huang1, Kun Wang2, Yang Liu4, Junfeng Fang2, Yongbin Li1
1Alibaba Group, 2University of Science and Technology of China,
3Tsinghua University, 4Nanyang Technological University
Large language models (LLMs) achieve state-of-the-art performance on multiple language tasks, yet their safety guardrails can be circumvented, leading to harmful generations. In light of this, recent research on safety mechanisms has emerged, revealing that when safety representations or components are suppressed, the safety capabilities of LLMs are compromised. However, existing research tends to overlook the safety impact of multi-head attention mechanisms, despite their crucial role in various model functionalities. Hence, in this paper, we explore the connection between standard attention mechanisms and safety capability to fill this gap in safety-related mechanistic interpretability. We propose a novel metric tailored for multi-head attention, the Safety Head ImPortant Score (Ships), to assess the contributions of individual heads to model safety. Based on this, we generalize Ships to the dataset level and further introduce the Safety Attention Head AttRibution Algorithm (Sahara) to attribute the critical safety attention heads inside the model. Our findings show that specific attention heads have a significant impact on safety. Ablating a single safety head allows an aligned model (
pip install -r requirements.txt
This part corresponds to Section 3 of our paper, and the main code is in `lib/SHIPS`.
In `lib/SHIPS/get_ships.py`, we define a `SHIPS` class that calculates, for each (harmful) query, which heads are important to the language model's safety.
The primary hyperparameter of the class is `mask_config`, which includes `mask_qkv`, `scale_factor`, and `mask_type`: `mask_qkv` specifies whether to modify the Q, K, or V matrix; `scale_factor` specifies the modification coefficient; `mask_type` specifies the modification type (mean or use …).
For a mini demo, see `Ships_quick_start.ipynb`.
This part corresponds to Section 4 of our paper, and the main code is in `lib/Sahara`.
In `lib/Sahara/attribution.py`, the primary function is `safety_head_attribution`.
For the function, the hyperparameters are `search_step`, which specifies how many collaborative safety heads to search for; `mask_qkv`, which specifies whether to modify the Q, K, or V matrix; `scale_factor`, which specifies the modification coefficient; and `mask_type`, which specifies the modification type (mean or use …).
For a mini demo, see `Generalized_Ships.ipynb`.
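As a rough sketch of how the attribution step fits together (the module path, function name, and the four hyperparameters above are from this README; the remaining argument names are assumptions — `Generalized_Ships.ipynb` shows the real call):

```python
# Hypothetical call sketch: arguments other than the documented hyperparameters
# (search_step, mask_qkv, scale_factor, mask_type) are assumptions.
from lib.Sahara.attribution import safety_head_attribution

harmful_queries = [
    "harmful query 1",
    "harmful query 2",
]

safety_heads = safety_head_attribution(
    model=model,               # model / tokenizer as in the Ships example above
    tokenizer=tokenizer,
    dataset=harmful_queries,   # assumed: an iterable of harmful prompts
    search_step=3,             # how many collaborative safety heads to search for
    mask_qkv="q",              # which projection to modify: "q", "k", or "v"
    scale_factor=0.0,          # modification coefficient
    mask_type="mean",          # modification type
)
print(safety_heads)            # e.g. a ranked list of (layer, head) pairs
```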
In the mini demo, the primary hyperparameters include …
Using Ships or Generalized Ships, we can attribute safety heads. Then, we can ablate a safety head following `Surgery.ipynb` to obtain an ablated LLM. The weights can also be loaded from `transformers.AutoModel` instead of `custommodel`.
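For reference, the following is a minimal, generic sketch of what zero-ablating one attention head can look like on a Llama-style model loaded through `transformers`; the model name and the (layer, head) indices are placeholders, and `Surgery.ipynb` remains the authoritative recipe.

```python
# Generic head-ablation sketch (not the Surgery.ipynb code): zero out one head's
# query projection, which corresponds to mask_qkv="q" with scale_factor=0.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # placeholder model

layer_idx, head_idx = 10, 7                                # placeholder (layer, head)
attn = model.model.layers[layer_idx].self_attn
head_dim = model.config.hidden_size // model.config.num_attention_heads
rows = slice(head_idx * head_dim, (head_idx + 1) * head_dim)

with torch.no_grad():
    # With zero queries, this head's attention scores collapse to a uniform
    # pattern, removing the information it normally routes.
    attn.q_proj.weight[rows, :] = 0.0

model.save_pretrained("ablated-model")                     # reuse like any checkpoint
```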