Dootmaan/DTFD-MIL.PyTorch

Hi, did you try to visualize the feature embeddings of the training/validation datasets?

Opened this issue · 0 comments

Did you try to visualize the feature embeddings of the training/validation datasets? I trained on my own WSI dataset (2 classes) with 210 training WSIs (84 negative, the rest positive) and 20 validation WSIs. The validation AUC reached 0.9; however, the t-SNE visualization of the embeddings looks pretty bad.

I use `slide_d_feat` as the embedding:

def test_attention_DTFD_preFeat_MultipleMean(mDATA_list, classifier, dimReduction, attention, UClassifier, epoch, criterion=None,  params=None, f_log=None, writer=None, numGroup=3, total_instance=3, distill='MaxMinS'):
    """Evaluate the DTFD model on ``mDATA_list`` and also collect embeddings.

    Runs under ``torch.no_grad()`` and returns ``(auc_1, embeddings)``.
    ``embeddings`` receives one ``slide_d_feat`` tensor per pass of the
    ``params.num_MeanInference`` loop for every slide visited, so there are
    ``num_MeanInference`` consecutive entries per slide when that is > 1.

    NOTE(review): parts of this body are elided (``......``) and the ``**...**``
    markers are issue formatting, not Python; it is not runnable as shown.
    """

    .......

    with torch.no_grad():

        numSlides = len(SlideNames)
        # NOTE(review): integer division. If each iteration takes batch_size_v
        # slides (the slicing is elided here), the last
        # numSlides % batch_size_v slides are never evaluated, and
        # `embeddings` ends up shorter than the label list.
        numIter = numSlides // params.batch_size_v
        tIDX = list(range(numSlides))
        

        **embeddings=[]**

        for idx in range(numIter):

            ......
            
            for tidx, tfeat in enumerate(batch_feat):
                ......

                # Mean-inference passes: everything appended below happens once
                # per pass, not once per slide.
                for jj in range(params.num_MeanInference):

                   ......
                    for tindex in index_chunk_list:
                        ......

                        # Distil one group: MaxMinS keeps the first and last
                        # instance_per_group entries of sort_idx
                        # (2 * instance_per_group rows), MaxS keeps only the
                        # first instance_per_group, AFS appends tattFeat_tensor.
                        if distill == 'MaxMinS':
                            topk_idx_max = sort_idx[:instance_per_group].long()
                            topk_idx_min = sort_idx[-instance_per_group:].long()
                            topk_idx = torch.cat([topk_idx_max, topk_idx_min], dim=0)
                            d_inst_feat = tmidFeat.index_select(dim=0, index=topk_idx)
                            slide_d_feat.append(d_inst_feat)
                        elif distill == 'MaxS':
                            topk_idx_max = sort_idx[:instance_per_group].long()
                            topk_idx = topk_idx_max
                            d_inst_feat = tmidFeat.index_select(dim=0, index=topk_idx)
                            slide_d_feat.append(d_inst_feat)
                        elif distill == 'AFS':
                            slide_d_feat.append(tattFeat_tensor)
                    

                    # Stack all groups' distilled features into one 2-D tensor.
                    # Row order follows index_chunk_list (its construction is
                    # elided). If that split is random, row i of one slide is
                    # unrelated to row i of another, so flattening this tensor
                    # is not a comparable per-slide embedding; pool over rows
                    # instead. NOTE(review): this looks like the input to the
                    # second tier (UClassifier) rather than that tier's
                    # aggregated slide feature. Verify which one is wanted.
                    slide_d_feat = torch.cat(slide_d_feat, dim=0) 


                    # One entry per slide only when params.num_MeanInference == 1.
                    **embeddings.append(slide_d_feat)**

                    slide_sub_preds = torch.cat(slide_sub_preds, dim=0)
                    slide_sub_labels = torch.cat(slide_sub_labels, dim=0)
                   
                    ......
        
                    # NOTE(review): as pasted, this return sits inside the
                    # `for jj` loop and would exit after the first pass of the
                    # first slide. It presumably belongs after the outer
                    # `for idx` loop; confirm the indentation.
                    return auc_1, embeddings

And then I visualize them:

# t-SNE visualization of the per-slide embeddings returned by
# test_attention_DTFD_preFeat_MultipleMean on the validation set, with each
# point coloured by its ground-truth slide label.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

auc_val, embedings_val = test_attention_DTFD_preFeat_MultipleMean(
    classifier=classifier,
    dimReduction=dimReduction,
    attention=attention,
    UClassifier=attCls,
    mDATA_list=(SlideNames_val, FeatList_val, Label_val),
    criterion=None,
    epoch=None,
    params=params,
    f_log=log_file,
    writer=None,
    numGroup=params.numGroup_test,
    total_instance=params.total_instance_test,
    distill=params.distill_type,
)

# Each entry of `embedings_val` is a 2-D `slide_d_feat` tensor: distilled
# instance features stacked group by group. If the groups come from a random
# split, row i of one slide has nothing to do with row i of another. In that
# case, flattening the rows into one long vector compares unrelated features
# position by position, and it also breaks if row counts differ between
# slides. Mean-pooling over rows is order-invariant and shape-agnostic
# (max-pooling over rows is a reasonable alternative).
embedings_val_np = np.stack(
    [embed.detach().float().mean(dim=0).cpu().numpy() for embed in embedings_val]
)
print(embedings_val_np.shape)  # (num_slides, embed_dim)

# Colour by the real labels instead of assuming the first half of the slides
# is one class and the second half the other.
# Assumes Label_val holds one scalar class label per slide, in evaluation order.
labels_val = np.array([int(label) for label in Label_val])
if len(labels_val) != len(embedings_val_np):
    # E.g. num_MeanInference > 1 (one embedding per pass), or the evaluation
    # loop skipped trailing slides.
    raise ValueError(
        f"got {len(embedings_val_np)} embeddings for {len(labels_val)} labels; "
        "expected exactly one embedding per validation slide"
    )

# scikit-learn requires perplexity < n_samples. The default (30) is too large
# for a ~20-slide validation set, so scale it to the number of points.
num_points = len(embedings_val_np)
perplexity = max(1.0, min(30.0, (num_points - 1) / 3))
val_vis_embed = TSNE(
    n_components=2,
    learning_rate="auto",
    init="random",
    perplexity=perplexity,
    random_state=0,  # fixed seed so repeated runs give the same layout
).fit_transform(embedings_val_np)

fig, ax = plt.subplots()
for cls in np.unique(labels_val):
    mask = labels_val == cls
    ax.scatter(val_vis_embed[mask, 0], val_vis_embed[mask, 1], label=f"class {cls}")
ax.set_title(f"t-SNE of validation slide embeddings (AUC = {float(auc_val):.3f})")
ax.legend()
plt.show()