Cross-layer attention weight sharing fails in different scopes
mqyqlx opened this issue · 3 comments
Hi. I'm trying to share attention weights across layers, following the test case in shared_layers_test.py:
```python
def testSharedTemplateLayer(self):
  sub_params = pax_fiddle.Config(
      linears.FeedForward, input_dims=8, output_dims=8
  )
  # Only share the linear projection, not the entire FeedForward layer.
  sub_params.linear_tpl.shared_weight_layer_id = 'shared_weight'
  test_layer_p = pax_fiddle.Config(
      SimpleShared01,
      name='test',
      sub1_tpl=sub_params.clone(),
      sub2_tpl=sub_params.clone(),
  )
  x_in = jnp.ones([2, 8])
  with base_layer.JaxContext.new_context():
    prng_key = jax.random.PRNGKey(1234)
    layer = base_layer.instantiate(test_layer_p)
    init_vars = layer.init(prng_key, x_in)
```
But weight sharing failed because different root scopes are used when the shared-layer cache entry is set and when it is looked up (see lookup_shared_layer and set_shared_layer below):
```python
def lookup_shared_layer(
    self, root_scope: flax_core.Scope, shared_layer_id: str
) -> _SharedLayerCacheEntry | None:
  logging.info('lookup_shared_layer called with id: %s in the scope of %s',
               shared_layer_id, root_scope)
  return self._root_scope_to_shared_layers_map[root_scope][shared_layer_id]

def set_shared_layer(self, root_scope: flax_core.Scope, shared_layer_id: str,
                     wrapper: _WrapperLayer, layer_hparams):
  logging.info('set_shared_layer called with id: %s in the scope of %s',
               shared_layer_id, root_scope)
  existing = self.lookup_shared_layer(root_scope, shared_layer_id)
  assert existing is None
  self._root_scope_to_shared_layers_map[root_scope][
      shared_layer_id] = _SharedLayerCacheEntry(
          layer=wrapper.cld, hparams=layer_hparams.clone(), wrapper=wrapper)
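```

To illustrate the symptom I am seeing: the cache is keyed by the root scope object, so if the set and the lookup happen under two different root scopes, the lookup misses and an unshared copy of the weights is created. A minimal, hypothetical stand-in (not the actual Praxis code):

```python
from collections import defaultdict

# Hypothetical stand-in for the shared-layer cache: entries are keyed by the
# root scope *object*, so two distinct scope objects never see each other's
# entries.
root_scope_to_shared_layers_map = defaultdict(dict)

class FakeScope:
  """Stand-in for flax_core.Scope; object identity is the dict key."""

scope_a = FakeScope()  # root scope seen when the first layer registers
scope_b = FakeScope()  # a different root scope seen at lookup time

root_scope_to_shared_layers_map[scope_a]['shared_attn_0'] = 'cached wrapper'

# The lookup under a different root scope misses, so a brand-new (unshared)
# attention layer is created instead of reusing the cached one.
print(root_scope_to_shared_layers_map[scope_b].get('shared_attn_0'))  # None
```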
Specifically, I implemented a 24-layer Llama with StackedTransformer (not StackedTransformerRepeated) and set shared_weight_layer_id in an interleaved pattern with an interval of 6 inside StackedTransformer's setup function. The main code differences (the share_interval attribute and the shared_weight_layer_id assignment in _layer_params) are marked with `# Added:` comments in the block below. I also set remat=True and checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING on the StackedTransformer.
```python
class StackedTransformer(base_layer.BaseLayer):
  use_cross_attention: bool = False
  mask_self_attention: bool = False
  num_layers: int = 0
  model_dims: int = 0
  hidden_dims: int = 0
  num_heads: int = 0
  dim_per_head: int | None = None
  dropout_prob: float = 0.0
  atten_dropout_prob: float | None = None
  residual_dropout_prob: float | None = None
  relu_dropout_prob: float | None = None
  residual_droppath_prob: float = 0.0
  input_dropout_prob: float = 0.0
  gating_func: str = 'top2'
  unadjusted_expert_capacity_factor: float = 2.0
  transformer_layer_params_tpl: LayerTpl | Sequence[LayerTpl] = template_field(
      Transformer
  )
  packed_input: bool = False
  fold_padding_with_segment_mask: bool = False
  moe_layer_tpl: LayerTpl | None = template_field(TransformerFeedForwardMoe)
  num_experts: int = 0
  num_groups: int = 1
  min_group_size: int | None = None
  moe_layers: Sequence[int] | None = ()
  ngrammer_tpls: Sequence[LayerTpl] | None = template_field(None)
  remat: bool = False
  share_interval: int = 6  # Added: interval for interleaved weight sharing.
  checkpoint_policy: AutodiffCheckpointType = (
      AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS_FFN1
  )

  def _clone_layer_params(self, layer_tpl: LayerTpl) -> LayerTpl:
    """Useful to let subclasses switch the class (e.g. Streaming version)."""
    return layer_tpl.clone()

  def setup(self) -> None:
    assert self.num_layers > 0
    assert self.model_dims > 0
    assert self.hidden_dims > 0
    assert self.num_heads > 0
    assert 0.0 <= self.dropout_prob < 1.0
    assert 0.0 <= self.input_dropout_prob < 1.0

    def _layer_params(i):
      """Construct i-th layer params."""
      if isinstance(self.transformer_layer_params_tpl, Sequence):
        factor = self.num_layers // len(self.transformer_layer_params_tpl)
        ii = i // factor
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii])
      else:
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl)
      p_i.name = f'layer_{i}'
      # Added: ii is in the range [0, 5] when share_interval = 6, so layers
      # whose indices are congruent modulo 6 share attention weights.
      ii = i % self.share_interval
      p_i.tr_atten_tpl.shared_weight_layer_id = f'shared_attn_{ii}'
      p_i.use_cross_attention = self.use_cross_attention
      p_i.num_heads = self.num_heads
      p_i.dim_per_head = self.dim_per_head
      p_i.input_dims = self.model_dims
      p_i.packed_input = self.packed_input
      p_i.atten_dropout_prob = self.atten_dropout_prob or self.dropout_prob
      p_i.residual_dropout_prob = (
          self.residual_dropout_prob or self.dropout_prob
      )
      p_i.relu_dropout_prob = self.relu_dropout_prob or self.dropout_prob
      p_i.hidden_dims = self.hidden_dims

      if self.residual_droppath_prob > 0.0:
        p_i.residual_droppath_prob = (
            self.residual_droppath_prob * i / max(1, self.num_layers)
        )

      if self.moe_layers and i in self.moe_layers:
        assert self.num_experts > 0
        assert self.moe_layer_tpl is not None
        moe_p = self.moe_layer_tpl.clone()
        moe_p.num_experts = self.num_experts
        moe_p.num_groups = self.num_groups
        moe_p.min_group_size = self.min_group_size
        moe_p.gating_func = self.gating_func
        if moe_p.hidden_dims:
          # MoE hidden_dims could be different from FFN hidden_dims
          p_i.hidden_dims = moe_p.hidden_dims
        p_i.tr_fflayer_tpl = moe_p

      if self.ngrammer_tpls is not None:
        if self.ngrammer_tpls[i] is not None:
          p_i.ngrammer_tpl = self.ngrammer_tpls[i]
      return p_i

    if isinstance(self.transformer_layer_params_tpl, (list, tuple)):
      if self.num_layers % len(self.transformer_layer_params_tpl):
        raise ValueError(
            'num_layers should be divisible by transformer_layer_params_tpl'
        )

    layer_params = [_layer_params(i) for i in range(self.num_layers)]
    self.create_children('x_layers', layer_params)

    if self.input_dropout_prob > 0.0:
      self.create_child(
          'input_dropout',
          pax_fiddle.Config(
              stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob
          ),
      )
```
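For reference, this is the sharing pattern the code above is intended to produce with num_layers=24 and share_interval=6; a quick standalone check of the grouping logic, not Praxis code:

```python
# Layer i is tagged with shared_attn_{i % share_interval}, so layers whose
# indices are congruent modulo share_interval should share attention weights.
num_layers, share_interval = 24, 6
groups = {}
for i in range(num_layers):
  groups.setdefault(f'shared_attn_{i % share_interval}', []).append(i)

for shared_id, layers in groups.items():
  print(shared_id, layers)
# shared_attn_0 [0, 6, 12, 18]
# shared_attn_1 [1, 7, 13, 19]
# ...
# shared_attn_5 [5, 11, 17, 23]
```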
Could you explain why the root scopes are different when sharing attention weights across layers? Is it related to layer-wise checkpointing?
I would be grateful for a demonstration of how to share attention weights, or any other advice you might offer.