greenelab/deep-review

Deep Survival Analysis

sw1 opened this issue · 9 comments

sw1 commented

http://arxiv.org/abs/1608.02158

The electronic health record (EHR) provides an unprecedented opportunity to build actionable tools to support physicians at the point of care. In this paper, we investigate survival analysis in the context of EHR data. We introduce deep survival analysis, a hierarchical generative approach to survival analysis. It departs from previous approaches in two primary ways: (1) all observations, including covariates, are modeled jointly conditioned on a rich latent structure; and (2) the observations are aligned by their failure time, rather than by an arbitrary time zero as in traditional survival analysis. Further, it (3) scalably handles heterogeneous (continuous and discrete) data types that occur in the EHR. We validate the deep survival analysis model by stratifying patients according to risk of developing coronary heart disease (CHD). Specifically, we study a dataset of 313,000 patients corresponding to 5.5 million months of observations. When compared to the clinically validated Framingham CHD risk score, deep survival analysis is significantly superior in stratifying patients according to their risk.

It's based on the deep exponential family stuff Blei's lab has been working on: https://arxiv.org/abs/1411.2581

@sw1: Are you willing to provide a summary & discussion of the paper? Something like the start of the conversation from #43 or #10?

sw1 commented

@cgreene Sure! I intend to, but I figured I should at least post the abstract before I get around to it, right?

sw1 commented

Summary

The paper has a novel take on survival analysis, using deep exponential families. I think the big take-home here is the flexibility that DEFs, and in turn DSA, have for modeling various data types and making use of a variety of priors. On an oversimplified level, it does come off as a gigantic hierarchical GLM, like something you'd see coded in Stan, but just much bigger. The non-linear structure it can capture clearly sets it apart.

With that said, DEFs seem interesting, particularly for unsupervised count data. The paper presents, as one of its examples, a deep topic model, such that you have topics, super-topics, and concepts. A lot of data can be modeled in terms of count frequencies (metagenomic data and gene counts, for example), so I'm curious about the type of features that would result compared to, say, LDA. Its final example aimed to show how DEFs can be "embedded and composed in more complex models," so they modeled pairwise data.

Background

This paper is based on deep exponential families (DEFs) – also from the Blei lab. DEFs are fully Bayesian latent variable models with similarities to hierarchical modeling (“random effects” and “multilevel” models), as well as to deep neural networks. They should be intuitive if you are familiar with GLMs. Effectively, layers are linked together using a link function g (in the same fashion one uses a “log link” in a Poisson GLM to map the inner product of the features and weights to the Poisson’s natural parameter) and an exponential family distribution; thus, p(z_l | z_{l+1}, W_l) = EXPFAM_l(z_l, g_l(z_{l+1}^T W_l)). The observation layer is then p(x | z_1, W_0) = EXPFAM(x, g(z_1^T W_0)), and the observations can be modeled as counts, reals, binary values, or multinomials. The choice of the observation model and the prior placed on the weights W is analogous to choosing the type of regularization and loss function in a neural network. Posterior approximation is achieved via variational inference.
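To make the layer-wise construction concrete, here is a minimal sampling sketch, assuming small Poisson layers with a softplus link; the layer sizes, the Poisson choice, and the link are illustrative placeholders, not the configuration used in either paper.

```python
# Minimal sketch of sampling through a deep exponential family (DEF).
# Not the authors' code; sizes, the Poisson layers, and the softplus
# link are placeholders chosen for illustration.
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [100, 25, 5]   # z_1 (bottom) ... z_3 (top), purely illustrative
vocab_size = 1000            # dimension of the observation x

def softplus(a):
    # link function g mapping real-valued activations to a positive
    # mean parameter, analogous to the log link in a Poisson GLM
    return np.log1p(np.exp(a))

# weights W_l linking layer l+1 to layer l, plus W_0 for the observation
W = {l: rng.normal(scale=0.1, size=(layer_sizes[l], layer_sizes[l - 1]))
     for l in range(1, len(layer_sizes))}
W[0] = rng.normal(scale=0.1, size=(layer_sizes[0], vocab_size))

# top layer: draw z_L from its prior
z = rng.gamma(shape=0.3, scale=1.0, size=layer_sizes[-1])

# descend: p(z_l | z_{l+1}, W_l) = EXPFAM_l(z_l, g_l(z_{l+1}^T W_l))
for l in range(len(layer_sizes) - 1, 0, -1):
    rate = softplus(z @ W[l])
    z = rng.poisson(rate)    # Poisson chosen here for count-valued layers

# observation layer: p(x | z_1, W_0) = EXPFAM(x, g(z_1^T W_0))
x = rng.poisson(softplus(z @ W[0]))
print(x.shape, x.sum())
```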

Deep Survival Analysis (DSA)

Aim

Apply DSA to EHR data to predict disease risk

Justification:

Current survival analysis based on regression is limited by

  1. Missing data and sparsity
  2. A lack of a common starting point across records
  3. Non-linearities inherent to the data

DSA, on the other hand, (1) is fully Bayesian, (2) aligns patients with respect to their failure time, and (3) is based on DEFs, which provide non-linear latent structure.

Methods

  • EHRs spanning 313,000 patient records to assess risk of coronary heart disease (CHD)
  • Covariates included reals (labs and vitals) and counts (medications and diagnosis codes)
  • Used a Gaussian for the EXPFAM distribution, where the mean and precision of each layer are two-layer perceptrons with ReLUs; Gaussian priors on the weights and intercepts; a Weibull distribution for the observation model, a distribution common in survival analysis; a t distribution on real-valued covariates to combat outliers; and a binary model on the count-type covariates, such that the regression coefficients are given a log-normal prior and the inner product of a latent variable z with these coefficients is mapped to 0 or 1 (a rough sketch of these pieces follows this list).
  • Because of a lack of consistent start time across records, they measured time starting at the event (i.e., the “failure” in traditional survival analysis) and worked backwards.
  • Data were labeled based on the time interval and whether the observation was censored (terminated before the failure).
  • Only 11.8% of patients had the most basic, critical set of variables (LDL, HDL, BP) present in their record for a given month; only 1.4% of all months had this complete set.
  • Training (n = 263,000) took 7.5 hours (6,000 iterations) on a 40-core, 384 GB RAM server.
  • Tested multiple latent variable z sizes K (5, 10, 25, 50, 75, 100) with matching layer sizes.
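A rough sketch of how these pieces fit together, assuming a single Gaussian latent layer and small ReLU perceptrons mapping z to positive Weibull parameters; the dimensions, priors, and exact Weibull parameterization here are assumptions for illustration, not the authors' implementation.

```python
# Sketch of the DSA pieces described above (not the authors' code):
# a Gaussian latent z, small ReLU perceptrons, and a Weibull model
# for time-to-failure. All sizes and the parameterization are assumed.
import numpy as np

rng = np.random.default_rng(1)
K = 50                       # latent dimension (best-performing size per the paper)
H = 32                       # hidden width of the perceptrons (illustrative)

def relu(a):
    return np.maximum(a, 0.0)

def mlp(z, W1, b1, W2, b2):
    # two-layer perceptron with ReLU hidden units
    return relu(z @ W1 + b1) @ W2 + b2

softplus = lambda a: np.log1p(np.exp(a))

# weights drawn from their (assumed) Gaussian priors
params = {name: rng.normal(scale=0.1, size=shape)
          for name, shape in [("W1", (K, H)), ("b1", (H,)),
                              ("W2_shape", (H, 1)), ("b2_shape", (1,)),
                              ("W2_scale", (H, 1)), ("b2_scale", (1,))]}

z = rng.normal(size=K)       # one patient's latent representation

# map z to positive Weibull shape/scale via softplus
shape_k = softplus(mlp(z, params["W1"], params["b1"],
                       params["W2_shape"], params["b2_shape"]))[0]
scale_lam = softplus(mlp(z, params["W1"], params["b1"],
                         params["W2_scale"], params["b2_scale"]))[0]

# sample a time-to-failure from Weibull(shape_k, scale_lam)
t = scale_lam * rng.weibull(shape_k)
print(f"shape={shape_k:.2f}, scale={scale_lam:.2f}, sampled time={t:.2f} months")
```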

Results

  • The DSA parameterization with the best concordance on a held-out set: K = 50
  • The covariate group most predictive of CHD risk: diagnosis codes

Because of a lack of consistent start time across records, they measured time starting at the event (i.e., the “failure” in traditional survival analysis) and worked backwards.

@sw1 For patients who survive, do they use the last observation and work backwards?

sw1 commented

@agitter Effectively yeah.

Data (say a patient A with a failure and a patient B censored with unknown failure):
A: _start_i--------------------------------failure
B: ____start_j----------------censor

DSA:
A: failure--------------------------------start_i
B: censor----------------start_j

A and B both would have two variables in the model: (1) the intervals (time from failure/censor to start) and (2) whether it was a failure or censor, so

A = (32, 0) [please don't ask why I actually counted the dashes]
B = (16, 1)

One of the nice statistical features of setting the data up this way is that the observations are now exchangeable. Lastly, the paper notes that the likelihood for a censored observation is the probability mass remaining after the censoring time, 1 - CDF(t).
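As a small illustration of that likelihood treatment, here is a sketch using scipy's Weibull (the time-to-event model the paper uses); the shape/scale values and the helper function are made up for illustration.

```python
# Sketch of the censored-vs-observed likelihood split described above.
# The shape/scale values are illustrative, not fitted parameters.
from scipy.stats import weibull_min

shape_k, scale_lam = 1.5, 40.0   # assumed Weibull parameters (months)

def log_likelihood(t, censored):
    """t: months from failure/censoring back to the start of the record;
    censored: True if the record ended before a failure was observed."""
    dist = weibull_min(c=shape_k, scale=scale_lam)
    if censored:
        # probability mass remaining past the censoring time: log(1 - CDF(t))
        return dist.logsf(t)
    # observed failure: ordinary log-density at t
    return dist.logpdf(t)

# patient A from the diagram: failure observed 32 months into the record
# patient B from the diagram: censored 16 months into the record
print(log_likelihood(32, censored=False))
print(log_likelihood(16, censored=True))
```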

Closing issue since discussion on it seems to have stopped. Agree that it makes sense to talk about this in the paper. @sw1: do you have the lead on discussing this work as an example of how we categorize disease?

sw1 commented

@cgreene Yeah. I can do that!

Is there a deep-net version of this method?
From my understanding, the "deep" aspect of this work is in the latent exponential families, but it is not deep in the sense of neural networks (e.g., several hidden layers, regularization, etc.).

@ibarrien I don't recall seeing discussion of any follow-up work to this paper in our issues here.

@sw1 are you aware of any?