Default implementation of Jamba
erlebach opened this issue · 2 comments
Here is a section of code in `JambaLM`:
```python
class Jamba(nn.Module):
    def __init__(self, config: JambaLMConfig):
        super().__init__()
        self.config = config

        # init each model layer, decide if it's mamba/attention and has experts or not
        decoder_layers = []
        for i in range(config.n_layers):
            is_attn = (
                True
                if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0
                else False
            )
            is_expert = (
                True
                if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0
                else False
            )
```
You'll notice that the structure of `is_attn` and `is_expert` is identical. Furthermore, in the default configuration provided, either `is_attn = is_expert = False`, or they are both true at the same time. As a result, all the layers in this default Jamba architecture are the same. Of course I can change that, but this is surely not intended given that this code is didactic. Thanks.
> You'll notice that the structure of `is_attn` and `is_expert` is identical.
Why do you say that?
```python
is_attn = True if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0 else False
is_expert = True if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0 else False
```
For a given `i`, `is_attn` and `is_expert` are certainly not the same, given that the offsets and periods are different (and indeed they are different in the default config).
> As a result, all the layers in this default Jamba architecture are the same.
I don't know how you can assume that.
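For what it's worth, here is a minimal standalone sketch of that selection logic. The offset and period values below are assumed purely for illustration (they are not necessarily the defaults in `JambaLMConfig`); the point is only that different offsets/periods give different layer types per index.

```python
# Standalone sketch of the layer-type selection logic quoted above.
# The config values here are assumed for illustration only; they are not
# necessarily the defaults used by JambaLMConfig in this repo.
n_layers = 8
attn_layer_offset, attn_layer_period = 4, 8      # assumed
expert_layer_offset, expert_layer_period = 1, 2  # assumed

for i in range(n_layers):
    is_attn = (i - attn_layer_offset) % attn_layer_period == 0
    is_expert = (i - expert_layer_offset) % expert_layer_period == 0
    mixer = "attention" if is_attn else "mamba"
    mlp = "MoE" if is_expert else "MLP"
    print(f"layer {i}: {mixer} + {mlp}")
```

With these assumed values, only layer 4 is an attention layer while every odd-indexed layer gets experts, so `is_attn` and `is_expert` clearly do not track each other.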
UPDATE: I printed out the layers, and they alternate between `attn` and `expert` as they should. So it is on me to figure this out.
Thanks for replying to me.