Valentyn1997/CausalTransformer

Question about src/models/utils_transformer.py

Closed this issue · 2 comments

Hi, thank you so much for your paper and code. I think this is great work.
I am trying to apply this model to real-world data.
May I ask some questions about the TransformerMultiInputBlock class in src/models/utils_transformer.py?

I found that the self-attention of x_v is computed with the self.self_attention_o layer, which is meant for the self-attention of x_o. I think it should instead go through a dedicated self.self_attention_v layer. I wonder whether this could affect the model's results. If I am missing something, I would appreciate your comments.

I also found similar issues in the cross-attention calculations, namely x_tv_, x_ov_, x_vt_, and x_vo_. Each of them reuses either the self.cross_attention_to or the self.cross_attention_ot layer, which are meant for the cross-attention between x_o and x_t (see the sketch after the snippet below for the change I have in mind).

Here is the corresponding code.
```python
if self.n_inputs == 2:
    out_t = self.feed_forwards[0](x_to_ + x_s)
    out_o = self.feed_forwards[1](x_ot_ + x_s)

    return out_t, out_o

else:
    self_att_mask_v = active_entries_vitals.repeat(1, 1, x_v.size(1)).unsqueeze(1)
    cross_att_mask_ot_v = (active_entries_vitals.squeeze(-1).unsqueeze(1) * active_entries_treat_outcomes).unsqueeze(1)
    cross_att_mask_v_ot = (active_entries_treat_outcomes.squeeze(-1).unsqueeze(1) * active_entries_vitals).unsqueeze(1)

    x_tv_ = self.cross_attention_to(x_t_, x_v, x_v, cross_att_mask_ot_v, True) if not self.disable_cross_attention \
        and self.isolate_subnetwork != 't' and self.isolate_subnetwork != 'v' else 0.0
    x_ov_ = self.cross_attention_to(x_o_, x_v, x_v, cross_att_mask_ot_v, True) if not self.disable_cross_attention \
        and self.isolate_subnetwork != 'o' and self.isolate_subnetwork != 'v' else 0.0

    x_v_ = self.self_attention_o(x_v, x_v, x_v, self_att_mask_v, True)
    x_vt_ = self.cross_attention_ot(x_v_, x_t, x_t, cross_att_mask_v_ot, True) if not self.disable_cross_attention \
        and self.isolate_subnetwork != 'v' and self.isolate_subnetwork != 't' else x_v_
    x_vo_ = self.cross_attention_ot(x_v_, x_o, x_o, cross_att_mask_v_ot, True) if not self.disable_cross_attention \
        and self.isolate_subnetwork != 'v' and self.isolate_subnetwork != 'o' else 0.0

    out_t = self.feed_forwards[0](x_to_ + x_tv_ + x_s)
    out_o = self.feed_forwards[1](x_ot_ + x_ov_ + x_s)
    out_v = self.feed_forwards[2](x_vt_ + x_vo_ + x_s)
```
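For concreteness, here is the kind of change I have in mind, sketched under the assumption that the block defines (or would define) vitals-specific layers named self.self_attention_v, self.cross_attention_tv, self.cross_attention_ov, self.cross_attention_vt, and self.cross_attention_vo. These names are my guess, not taken from the repository, so please read this only as an illustration of the routing I mean:

```python
# Hypothetical sketch (layer names are assumptions): route the vitals stream
# through its own self-attention and pair-specific cross-attention layers
# instead of reusing the o/t layers.

x_tv_ = self.cross_attention_tv(x_t_, x_v, x_v, cross_att_mask_ot_v, True) if not self.disable_cross_attention \
    and self.isolate_subnetwork != 't' and self.isolate_subnetwork != 'v' else 0.0
x_ov_ = self.cross_attention_ov(x_o_, x_v, x_v, cross_att_mask_ot_v, True) if not self.disable_cross_attention \
    and self.isolate_subnetwork != 'o' and self.isolate_subnetwork != 'v' else 0.0

# Self-attention of x_v through a dedicated vitals layer.
x_v_ = self.self_attention_v(x_v, x_v, x_v, self_att_mask_v, True)

x_vt_ = self.cross_attention_vt(x_v_, x_t, x_t, cross_att_mask_v_ot, True) if not self.disable_cross_attention \
    and self.isolate_subnetwork != 'v' and self.isolate_subnetwork != 't' else x_v_
x_vo_ = self.cross_attention_vo(x_v_, x_o, x_o, cross_att_mask_v_ot, True) if not self.disable_cross_attention \
    and self.isolate_subnetwork != 'v' and self.isolate_subnetwork != 'o' else 0.0
```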

Thank you

Hi!

Thanks for pointing this out; this is indeed a huge bug. I pushed an update to the code, and I will re-run the semi-synthetic and real-world experiments. Nevertheless, I don't expect much difference in performance, as transformers are fairly over-parametrized models.

Have a nice day!