dohlee/chromoformer

Can't replicate regression performance

Closed this issue · 6 comments

Hey!

Well done again on a very interesting model!

I have been trying to replicate the training of Chromoformer with the regression head, but I am seeing far worse performance than what you reported in your manuscript: I get a correlation of ~0.4, which is much lower than your ~0.8. This is roughly the same for both the training and validation cohorts, just slightly lower for validation. Clearly there is some issue in my approach, so I wanted to confirm some things I did (mostly because the model in the GitHub repo is set up for classification rather than regression, so I'm guessing there are necessary changes I'm missing).

I have listed the changes I made relative to this GitHub repo below. Everything else is exactly as you did it, including the training script and the model parameters (set to the model defaults and the values in the config file):

epochs=10
batch_size=64
lr = 3e-5
gamma = 0.87
i_max = 8
embed_n_layers = 1
embed_n_heads = 2
embed_d_model = 128
embed_d_ff = 128
pw_int_n_layers = 2
pw_int_n_heads = 2
pw_int_d_model = 128
pw_int_d_ff = 256
reg_n_layers = 6
reg_n_heads = 8
reg_d_model = 256
reg_d_ff = 256
head_n_feats = 128

Changes:

  • Model changes - Updated the model to have a single output with linear activation for the regression task:
self.fc_head = nn.Sequential(
    nn.Linear(embed_d_model * 3, head_n_feats),
    nn.ReLU(),
    nn.Linear(head_n_feats, 1),  # was 2 for classification
)
  • Data loader changes - Updated Roadmap3D to load the RPKM expression value rather than the label:
meta.expression = np.log2(meta.expression+5e-324)
self.ensg2exp = {r.gene_id:r.expression for r in self.meta.to_records()}

I then returned log2RPKM from the dataloader and used this as the Y for my regression model with MSE as the loss function:

criterion = nn.MSELoss()
loss = criterion(out, d['log2RPKM'].float().unsqueeze(1))

I then monitor the pearsonR in the training script:

#using torch metrics functional pearson R function
from torchmetrics.functional import pearson_corrcoef
...
        val_out.append(out.cpu())
        val_label.append(d['log2RPKM'].unsqueeze(1).cpu())

val_out = torch.cat(val_out)
val_label = torch.cat(val_label)

val_loss = criterion(val_out, val_label)

# Metrics.
val_mse = metrics.mean_squared_error(val_label, val_out)
#since shape is (X,1) convert to (X)
val_corr = pearson_corrcoef(val_label[:,0], val_out[:,0])
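
As a sanity check that does not depend on torchmetrics, the correlation can also be computed by hand. The following is a minimal sketch using only the standard library; `pearson_r` and the example values are hypothetical and not part of the Chromoformer repo.

```python
import math
from typing import Sequence


def pearson_r(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("inputs must have the same length (at least 2)")
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    # Sums of centered products and centered squares.
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / math.sqrt(var_x * var_y)


# Hypothetical labels and predictions, for illustration only.
labels = [0.0, 1.0, 2.0, 3.0]
preds = [0.1, 0.9, 2.2, 2.8]
print(pearson_r(labels, preds))
```

With the tensors above, something like `pearson_r(val_label[:, 0].tolist(), val_out[:, 0].tolist())` should agree with `val_corr` up to floating-point error.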

It is also worth noting that I tested a simple CNN architecture on this same data and got a higher correlation score of ~0.55.

Can you spot anything I should change about the model to replicate the performance in the manuscript?

Cheers,
Alan.

Hey! If it's easier, can you just share the files that differ from those in the repo that were used to train the regression head chromoformer model from your paper? I can work on comparing from there. Thanks, Alan.

Sorry for the late reply. I've been on military training for the last 3 weeks, so I could not check any emails or issues.

I'll update the code and scripts for reproducing Chromoformer-reg shortly. My apologies again.

Best regards,
Dohoon

No problem at all, thanks for the update! Could you update this issue when you get a chance to add it, so I can test?

Cheers,
Alan.

@Al-Murphy Could you try transforming the expression values with log2(x+1) instead of log2(x+5e-324)? I think that's the only difference for now. log2(x+5e-324) results in a somewhat skewed distribution of gene expression values, so I generally prefer log2(x+1). Please refer to the transformed distribution below (using demo data).

[Image: distribution of transformed expression values (demo data)]
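
To make the effect of the two offsets concrete, here is a minimal sketch using only the standard library. The RPKM values are hypothetical and not taken from the demo data, and the helper names `log2_tiny` and `log2_plus_one` are made up for illustration.

```python
import math

TINY = 5e-324  # smallest positive double, equal to 2**-1074


def log2_tiny(rpkm: float) -> float:
    """log2(x + 5e-324), the transform used in the original post."""
    return math.log2(rpkm + TINY)


def log2_plus_one(rpkm: float) -> float:
    """log2(x + 1), the suggested transform."""
    return math.log2(rpkm + 1.0)


# Hypothetical RPKM values, including an unexpressed gene (0.0).
rpkm_values = [0.0, 0.5, 1.0, 7.0, 1023.0]

print(f"{'RPKM':>8} {'log2(x+5e-324)':>16} {'log2(x+1)':>10}")
for x in rpkm_values:
    print(f"{x:>8} {log2_tiny(x):>16.2f} {log2_plus_one(x):>10.2f}")
```

Because 5e-324 is 2^-1074, an unexpressed gene maps to -1074 under the first transform, while for ordinary values the offset is lost to rounding and the result is simply log2(x). Under log2(x+1) the same gene maps to 0 and all values stay non-negative, so zeros no longer sit far from the rest of the distribution.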

Please let me know whether it becomes reproducible using the log2(x+1) transformation. I'll be cleaning up the code for Chromoformer-reg and Chromoformer-diff in the meantime.

Thank you again for your interest in our model!

Best regards,
Dohoon

Hey @dohlee, thanks for the suggestion. I didn't realise adding 5e-324 would skew the transformed expression values so much! I have updated the code as suggested. Although this did change performance, the correlation is still far lower than expected, at 0.45.

Thanks,
Alan.

@dohlee Apologies, it is replicating the reported performance now. There was an issue with my data loader! I appreciate all the help!

Alan.