liupei101/AdvMIL

High c-index for patchGCN on LUAD dataset

Raymvp opened this issue · 11 comments

Raymvp commented

Dear Dr. Liu,
I came across a unique situation while working on a replication of PatchGCN, which I noticed you used in the experimental section of your recent paper.

In the course of my experiments, I observed a relatively high c-index of around 0.75 on the validation set of the LUAD dataset, while the c-index during training was around 0.68. This seemed unusual, as the original paper reported a c-index of only 0.585.
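
For reference, a toy example of how such a c-index (Harrell's concordance index) is typically computed, e.g., with lifelines; this is only an illustration, not the project's actual evaluation code:

```python
import numpy as np
from lifelines.utils import concordance_index

# Toy data: observed times, predicted scores (higher = predicted longer
# survival), and event indicators (1 = death observed, 0 = censored).
times = np.array([5.0, 10.0, 12.0, 3.0])
scores = np.array([0.2, 0.8, 0.9, 0.1])
events = np.array([1, 1, 0, 1])
print(concordance_index(times, scores, events))  # 1.0: perfectly concordant
```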

Did you encounter any anomalies or unusually high c-index values in your implementation of PatchGCN? Any insights you could provide would be immensely helpful and appreciated.

Thank you for your time.

Xin Liu

Our experimental results are roughly the same as those reported in the PatchGCN paper. If there is a significant gap between your results and the original ones, I suggest checking the following points:

  1. Make sure that you have followed the same experimental setting as PatchGCN, e.g., using 5-fold cross-validation.
  2. From my observations, a large c-index can be obtained if you set a relatively small number of time bins (the bins used to uniformly divide survival time into discrete intervals); see the sketch below this list.
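
A minimal sketch of such uniform binning, assuming survival times arrive as a NumPy array; the actual preprocessing in AdvMIL or PatchGCN may differ:

```python
import numpy as np

def discretize_survival_time(times: np.ndarray, num_bins: int = 4) -> np.ndarray:
    """Uniformly divide the observed time range into `num_bins` intervals
    and return each patient's bin index (0 .. num_bins - 1)."""
    edges = np.linspace(times.min(), times.max(), num_bins + 1)
    return np.clip(np.digitize(times, edges[1:-1]), 0, num_bins - 1)

times = np.array([2.0, 11.0, 25.0, 40.0])
print(discretize_survival_time(times, num_bins=4))  # -> [0 0 2 3]
```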

Hope this helps. Good luck!

Raymvp commented

The default number of time bins is 4, but I've tried setting it to 8 and 10, and the results didn't show significant differences. Could you tell me how many time bins you used? If possible, could you share your PatchGCN code with me? Thank you.

Our setting is also 4. Our code for PatchGCN is available here.

I highly recommend asking for help in the official PatchGCN repository. You would get more useful suggestions and insights from the PatchGCN authors on the questions you raised here.

Good luck!

Raymvp commented

Hello Dr. Liu! First, thank you for your earlier help! I already asked in an issue on the PatchGCN page but did not get a useful answer, so I tried running your code instead. Now I would like to know how to successfully run PatchGCN with your code. There are several missing parameters, which I set as follows:

pdh_dims: 384-1
mlp_hops: 3
mlp_norm: True
mlp_dropout: 0.25
opt_net: adam
opt_net_weight_decay: 1e-4
opt_net_lr: 0.0002

However, this did not work well. Running on a 4090 GPU with batch size 1, I got the error: RuntimeError: CUDA out of memory. Tried to allocate 1.11 GiB (GPU 5; 23.65 GiB total capacity; 21.38 GiB already allocated; 964.56 MiB free; 21.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF. How should the YAML for PatchGCN be set?
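
As a side note, the mitigation the error message itself suggests can be tried by setting the allocator option before PyTorch initializes CUDA; this reduces fragmentation, though it cannot fix a genuine capacity shortfall. A minimal sketch:

```python
import os

# Must be set before the first CUDA allocation, hence before importing torch.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # noqa: E402  (deliberately imported after setting the env var)

print(torch.cuda.is_available())
```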

I have also run into similar problems with PatchGCN. Based on my previous experiments, I suggest checking the following experimental parameters:

  1. Reduce the total number of patches per batch; around 1500 is usually fine (see the sketch below this list).
  2. Set the number of GCN layers to 1.
  3. Reduce the embedding dimension, e.g., change 384 to 128 or 256.
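
A minimal sketch of point 1, assuming patch features arrive as an (N, D) tensor; subsample_patches is a hypothetical helper, not part of the AdvMIL or PatchGCN codebases:

```python
import torch

def subsample_patches(feats: torch.Tensor, max_patches: int = 1500) -> torch.Tensor:
    """Randomly keep at most `max_patches` patch features per WSI (before
    building the graph) so that GPU memory stays bounded."""
    if feats.size(0) <= max_patches:
        return feats
    idx = torch.randperm(feats.size(0))[:max_patches]
    return feats[idx]
```
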
Raymvp commented

With the number of GCN layers the same as in the original paper and bp_every_batch: 3, your code still produced results exceeding those reported in the original paper and in your paper's reproduction.

[screenshots of the results]

Could the large improvement be caused by updates to the code's dependency packages?
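
One way to check this hypothesis would be to log the exact versions of the key dependencies for every run, so that results can be compared across environments; a minimal sketch:

```python
import numpy
import torch
import torch_geometric

# Record the versions used for a run to rule out dependency drift.
print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)
print("numpy:", numpy.__version__)
```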

Data preparation, such as the patch extraction settings and the patch feature extraction settings, may also affect the final prediction performance. Could you provide the complete YAML configuration file for the runs above?

Raymvp commented

Sure, here it is:

task: surv_nll # cont_gansurv
seed: 42
cuda_id: 5

wandb_dir: /home/ubuntu/AdvMIL-main # path to this repo
wandb_prj: patchGCN_defult # wandb project name
save_path: ./results-adv/brca-patchGCN # path to save log files during training and testing

# data

dataset: BRCA
path_patch: /home/ubuntu/dataset/TCGA_BRCA_features/extracted_feature/pt_files # path to patch features, for patch-based models
path_graph: /home/ubuntu/dataset/TCGA_BRCA_features/tcga_brca_20x_features/graph_euclidean_files # path to WSI graphs, for graph-based models
path_cluster: /data/nlst/processed/patch-l1-cluster8-ids # path to patch clusters, for cluster-based models
path_coordx5: null
path_label: ./table/tcga_brca_path_full.csv # path to the csv table with patient_id, pathology_id, t, e
feat_format: pt
time_format: ratio
time_bins: 4
data_split_path: ./data_split/tcga_brca-fold{}.npz # path to data split
data_split_seed: [0, 1, 2, 3, 4] # fold identifiers, used to fill the placeholder in data_split_path
save_prediction: True
train_sampling: null # w/o data sampling

# Backbone setting of MIL encoder

bcb_mode: graph # choose patch, cluster, or graph
bcb_dims: 1024-128-128 # the dims from input dim -> hidden dim -> embedding dim
# the parameters below were set by myself
pdh_dims: 128-4
mlp_hops: 1
mlp_norm: True
mlp_dropout: 0.25
opt_net: adam
opt_net_weight_decay: 0.00001
opt_net_lr: 0.0002

cuda_id: 5

# Generator setting (regarding the end part)

gen_dims: 128-1 # embedding dim -> out dim
gen_noi_noise: 0-1 # noise setting, 0-1 / 1-0 / 1-1
gen_noi_noise_dist: uniform # noise type, gaussian / uniform
gen_noi_hops: 1
gen_norm: False
gen_dropout: 0.6
gen_out_scale: sigmoid

# Discriminator

disc_type: prj # how to fuse X and t: cat (vector concatenation) / prj (vector projection)
disc_netx_in_dim: 1024 # input dim of X
disc_netx_out_dim: 128 # out dim of X
disc_netx_ksize: 1
disc_netx_backbone: avgpool
disc_netx_dropout: 0.25
disc_nety_in_dim: 1 # input dim of t
disc_nety_hid_dims: 64-128 # hidden dim of t
disc_nety_norm: False
disc_nety_dropout: 0.0
disc_prj_path: x
disc_prj_iprd: bag # choose bag (regular projection) / instance (RLIP)

# loss for all

loss_gan_coef: 0.004 # coefficient of GANLoss
loss_netD: bce # choose bce / hinge / wasserstein
loss_regl1_coef: 0.00001 # coefficient of L1 Regularization

# loss for discrete model

loss_mle_alpha: 0.0

# loss for continuous model

loss_recon_norm: l1 # l1/l2
loss_recon_alpha: 0.0
loss_recon_gamma: 0.0

# Optimizer

opt_netG: adam
opt_netG_lr: 0.00008 # learning rate of generator
opt_netG_weight_decay: 0.0005
opt_netD_lr: 0.00008 # learning rate of discriminator

# training
epochs: 20 # epoch numbers
batch_size: 1
bp_every_batch: 3
num_workers: 8 # work numbers for loading WSI features
es_patience: 3 # es: early stopping
es_warmup: 5
es_verbose: True
es_start_epoch: 0
gen_updates: 1 # 1/2
monitor_metrics: loss # metrics on validation set for early stopping

# test

times_test_sample: 30 # sampling times when predicting survival from each WSI.
log_plot: False
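
For completeness, a quick way to sanity-check that the file above parses as valid YAML and that the key fields match what the run actually used (the filename is hypothetical):

```python
import yaml

with open("patchgcn_brca.yaml") as f:
    cfg = yaml.safe_load(f)

# Spot-check the fields discussed in this thread.
print(cfg["time_bins"], cfg["bcb_dims"], cfg["batch_size"])
```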

Raymvp commented

These are the packages used in my environment:
name: clam
channels:

  • pytorch
  • conda-forge
  • https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  • https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  • defaults

dependencies:
  • _libgcc_mutex=0.1=conda_forge
  • _openmp_mutex=4.5=2_gnu
  • backcall=0.2.0=pyh9f0ad1d_0
  • backports=1.0=pyhd8ed1ab_3
  • backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  • blas=1.0=mkl
  • brotlipy=0.7.0=py37h27cfd23_1003
  • bzip2=1.0.8=h7b6447c_0
  • ca-certificates=2023.5.7=hbcca054_0
  • certifi=2023.5.7=pyhd8ed1ab_0
  • cffi=1.15.1=py37h5eee18b_3
  • charset-normalizer=2.0.4=pyhd3eb1b0_0
  • cryptography=39.0.1=py37h9ce1e76_0
  • cudatoolkit=11.3.1=h2bc3f7f_2
  • decorator=5.1.1=pyhd8ed1ab_0
  • entrypoints=0.4=pyhd8ed1ab_0
  • ffmpeg=4.3=hf484d3e_0
  • freetype=2.12.1=h4a9f257_0
  • giflib=5.2.1=h5eee18b_3
  • gmp=6.2.1=h295c915_3
  • gnutls=3.6.15=he1e5248_0
  • idna=3.4=py37h06a4308_0
  • intel-openmp=2021.4.0=h06a4308_3561
  • ipython=7.33.0=py37h89c1867_0
  • ipython_genutils=0.2.0=py_1
  • jedi=0.18.2=pyhd8ed1ab_0
  • jpeg=9e=h5eee18b_1
  • lame=3.100=h7b6447c_0
  • lcms2=2.12=h3be6417_0
  • ld_impl_linux-64=2.40=h41732ed_0
  • lerc=3.0=h295c915_0
  • libdeflate=1.17=h5eee18b_0
  • libffi=3.4.2=h7f98852_5
  • libgcc-ng=13.1.0=he5830b7_0
  • libgomp=13.1.0=he5830b7_0
  • libiconv=1.16=h7f8727e_2
  • libidn2=2.3.2=h7f8727e_0
  • libpng=1.6.39=h5eee18b_0
  • libsodium=1.0.18=h36c2ea0_1
  • libsqlite=3.42.0=h2797004_0
  • libstdcxx-ng=13.1.0=hfd8a6a1_0
  • libtasn1=4.16.0=h27cfd23_0
  • libtiff=4.5.0=h6a678d5_2
  • libunistring=0.9.10=h27cfd23_0
  • libwebp=1.2.4=h11a3e52_1
  • libwebp-base=1.2.4=h5eee18b_1
  • libzlib=1.2.13=hd590300_5
  • lz4-c=1.9.4=h6a678d5_0
  • matplotlib-inline=0.1.6=pyhd8ed1ab_0
  • mkl=2021.4.0=h06a4308_640
  • mkl-service=2.4.0=py37h7f8727e_0
  • mkl_fft=1.3.1=py37hd3c417c_0
  • mkl_random=1.2.2=py37h51133e4_0
  • ncurses=6.4=hcb278e6_0
  • nettle=3.7.3=hbbd107a_1
  • numpy=1.21.5=py37h6c91a56_3
  • numpy-base=1.21.5=py37ha15fc14_3
  • openh264=2.1.1=h4ff587b_0
  • openssl=1.1.1u=hd590300_0
  • parso=0.8.3=pyhd8ed1ab_0
  • pexpect=4.8.0=pyh1a96a4e_2
  • pickleshare=0.7.5=py_1003
  • pillow=9.4.0=py37h6a678d5_0
  • pip=23.1.2=pyhd8ed1ab_0
  • prompt-toolkit=3.0.38=pyha770c72_0
  • ptyprocess=0.7.0=pyhd3deb0d_0
  • pycparser=2.21=pyhd3eb1b0_0
  • pygments=2.15.1=pyhd8ed1ab_0
  • pyopenssl=23.0.0=py37h06a4308_0
  • pysocks=1.7.1=py37_1
  • python=3.7.16=h7a1cb2a_0
  • python-dateutil=2.8.2=pyhd8ed1ab_0
  • python_abi=3.7=2_cp37m
  • pytorch=1.12.0=py3.7_cuda11.3_cudnn8.3.2_0
  • pytorch-mutex=1.0=cuda
  • readline=8.2=h8228510_1
  • requests=2.28.1=py37h06a4308_0
  • six=1.16.0=pyh6c4a22f_0
  • sqlite=3.42.0=h2c6b66d_0
  • tk=8.6.12=h27826a3_0
  • torchaudio=0.12.0=py37_cu113
  • torchvision=0.13.0=py37_cu113
  • tornado=6.2=py37h540881e_0
  • traitlets=5.9.0=pyhd8ed1ab_0
  • typing_extensions=4.6.3=pyha770c72_0
  • urllib3=1.26.14=py37h06a4308_0
  • wcwidth=0.2.6=pyhd8ed1ab_0
  • wheel=0.40.0=pyhd8ed1ab_0
  • xz=5.4.2=h5eee18b_0
  • zeromq=4.3.4=h9c3ff4c_1
  • zlib=1.2.13=hd590300_5
  • zstd=1.5.4=hc292b87_0
  • pip:
    • absl-py==1.4.0
    • anyio==3.7.0
    • appdirs==1.4.4
    • argon2-cffi==21.3.0
    • argon2-cffi-bindings==21.2.0
    • ase==3.22.1
    • asgiref==3.6.0
    • astor==0.8.1
    • attrs==22.2.0
    • autobahn==19.5.1
    • autograd==1.5
    • autograd-gamma==0.5.0
    • automat==22.10.0
    • beautifulsoup4==4.12.2
    • bleach==3.1.5
    • brotli==1.0.9
    • cached-property==1.5.2
    • channels==2.3.1
    • click==8.1.6
    • cloudpickle==2.2.1
    • constantly==15.1.0
    • cycler==0.11.0
    • daphne==2.4.1
    • debugpy==1.6.7
    • defusedxml==0.7.1
    • django==2.2.28
    • djangorestframework==3.9.4
    • docker-pycreds==0.4.0
    • ecos==2.0.12
    • exceptiongroup==1.1.1
    • fastjsonschema==2.17.1
    • feather-format==0.4.1
    • formulaic==0.6.1
    • future==0.18.3
    • gast==0.5.3
    • gitdb==4.0.10
    • gitpython==3.1.32
    • google-pasta==0.2.0
    • googledrivedownloader==0.4
    • graphlib-backport==1.0.3
    • grpcio==1.51.3
    • gym==0.26.2
    • gym-notices==0.0.8
    • h5py==2.10.0
    • hyperlink==21.0.0
    • importlib-metadata==4.13.0
    • incremental==22.10.0
    • inflate64==0.3.1
    • interface-meta==1.3.0
    • ipykernel==6.16.2
    • ipywidgets==8.0.6
    • isodate==0.6.1
    • jinja2==3.1.2
    • joblib==1.2.0
    • jsonfield2==3.0.3
    • jsonschema==3.0.2
    • jupyter==1.0.0
    • jupyter-client==7.4.9
    • jupyter-console==6.6.3
    • jupyter-core==4.12.0
    • jupyter-server==1.24.0
    • jupyterlab-pygments==0.2.2
    • jupyterlab-widgets==3.0.7
    • keras-applications==1.0.8
    • keras-preprocessing==1.1.2
    • kiwisolver==1.4.4
    • lifelines==0.27.7
    • llvmlite==0.39.1
    • lz4==4.3.2
    • markdown==3.4.1
    • markupsafe==2.1.2
    • matplotlib==3.1.1
    • mistune==2.0.5
    • multivolumefile==0.2.3
    • mypy-extensions==0.4.4
    • nbclassic==1.0.0
    • nbclient==0.7.4
    • nbconvert==7.5.0
    • nbformat==5.8.0
    • nest-asyncio==1.5.6
    • networkx==2.6.3
    • nmslib==2.1.1
    • notebook==6.5.4
    • notebook-shim==0.2.3
    • numba==0.56.4
    • numexpr==2.8.4
    • observable==1.0.3
    • opencv-python==4.7.0.72
    • openslide-python==1.2.0
    • openslides==3.3
    • osqp==0.6.3
    • packaging==23.0
    • pandas==1.3.5
    • pandocfilters==1.5.0
    • pathtools==0.1.2
    • prometheus-client==0.17.0
    • protobuf==3.20.0
    • psutil==5.9.5
    • py7zr==0.20.5
    • pyarrow==12.0.1
    • pyasn1==0.4.8
    • pyasn1-modules==0.2.8
    • pybcj==1.0.1
    • pybind11==2.6.1
    • pycox==0.2.3
    • pycryptodomex==3.18.0
    • pyparsing==3.0.9
    • pypdf2==1.26.0
    • pyppmd==1.0.0
    • pyrsistent==0.19.3
    • python-louvain==0.16
    • pytz==2023.3
    • pyyaml==6.0
    • pyzmq==25.1.0
    • pyzstd==0.15.9
    • qdldl==0.1.7
    • qtconsole==5.4.3
    • qtpy==2.3.1
    • rdflib==6.3.2
    • roman==3.1
    • scikit-learn==1.0.2
    • scikit-survival==0.17.2
    • scipy==1.7.3
    • send2trash==1.8.2
    • sentry-sdk==1.28.1
    • service-identity==21.1.0
    • setproctitle==1.3.2
    • setuptools==41.6.0
    • smmap==5.0.0
    • sniffio==1.3.0
    • soupsieve==2.4.1
    • sqlparse==0.4.3
    • tensorboard==1.14.0
    • tensorboardx==1.9
    • tensorflow-estimator==1.14.0
    • tensorflow-gpu==1.14.0
    • termcolor==2.2.0
    • terminado==0.17.1
    • texttable==1.6.7
    • threadpoolctl==3.1.0
    • tinycss2==1.2.1
    • torch-geometric==2.3.1
    • torch-scatter==2.1.1
    • torch-sparse==0.6.17
    • torchtuples==0.2.2
    • tqdm==4.65.0
    • twisted==22.10.0
    • txaio==23.1.1
    • wandb==0.15.5
    • webencodings==0.5.1
    • websocket-client==1.5.3
    • websockets==8.1
    • werkzeug==2.2.3
    • widgetsnbextension==4.0.7
    • wrapt==1.15.0
    • zipp==3.15.0
    • zope-interface==6.0
prefix: /home/ubuntu/anaconda3/envs/clam

Raymvp commented

I have also run into similar problems with PatchGCN. Based on my previous experiments, I suggest checking the following experimental parameters:

  1. Reduce the total number of patches per batch; around 1500 is usually fine.
  2. Set the number of GCN layers to 1.
  3. Reduce the embedding dimension, e.g., change 384 to 128 or 256.

Hello Dr. Liu, I have recently been working on this again. I still find that PatchGCN easily runs out of GPU memory under the default settings. I am using a single 4090; notably, your experiments used 2x V100 and the original authors used 4x 2080. My question: when a single GPU runs out of memory even with batch_size = 1, is there a way to combine multiple GPUs to solve the out-of-memory problem?
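
For reference, data parallelism (e.g., torch.nn.DataParallel) splits work along the batch dimension, so with batch_size = 1 it cannot shrink the memory footprint of a single large sample; splitting the model itself across GPUs can. A minimal model-parallel sketch, where encoder and head are hypothetical stand-ins for the actual modules:

```python
import torch.nn as nn

class TwoGPUModel(nn.Module):
    """Place the (heavy) encoder on cuda:0 and the prediction head on
    cuda:1, so that one large sample is split across two GPUs."""
    def __init__(self, encoder: nn.Module, head: nn.Module):
        super().__init__()
        self.encoder = encoder.to("cuda:0")
        self.head = head.to("cuda:1")

    def forward(self, x):
        h = self.encoder(x.to("cuda:0"))  # heavy GCN part on GPU 0
        return self.head(h.to("cuda:1"))  # prediction head on GPU 1
```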