/BayesSMILES

Bayesian Segmentation Modeling for Longitudinal Epidemiological Studies

Primary Language: C++

BayesSMILES

Bayesian Segmentation ModelIng for Longitudinal Epidemiological Studies

Tutorial

The following script is used to apply the Bayesian hierarchical model to detect multiple change points based on the daily active infectious cases of COVID-19, while estimating the basic reproductive number R_0 between all pairs of adjacent change points.

Load input data

We load the input data for the 50 US states

# Load the prepared model inputs; the file is expected to define `input_data_list` #
load(file = "../Data/USstates_model_input.Rdata")
# Preview the first three daily records of one state (Texas here) #
knitr::kable(x = input_data_list[["Texas"]][seq_len(3), ])
I_obs R_obs C_obs population recovery_rate time date state
489 154 643 28996 0.1 1 2020-03-22 Texas
555 203 758 28996 0.1 2 2020-03-23 Texas
696 259 955 28996 0.1 3 2020-03-24 Texas

The first three columns are

  • “I_obs”: daily actively infectious cases
  • “R_obs”: daily removed/recovered cases
  • “C_obs”: daily confirmed cases

Select change points and estimate the basic reproduction number

We first load the functions used for our model.

# load R packages
library(Rcpp)
library(loo)
library(mcclust)

# load functions #
source("../functions/functions_JHU.R")
sourceCpp("../functions/covid_poi_fast.cpp")
sourceCpp("../functions/core_5m_sir.cpp")

Next, we show how to apply the proposed model to the COVID-19 data for Texas.

# Stack the per-state inputs into one table and look up the position of Texas #
input_dataframe <- do.call(what = rbind, args = input_data_list)
state_idx <- which(names(input_data_list) == "Texas")

# Detect the change points for Texas (fixed seed so the run is reproducible) #
cp_result <- detect_changepoint(
  input_dataframe,
  state_idx = state_idx,
  NITER = 10000,
  h_vec = c(100000, 10),
  seed = 1037210
)

# Estimate the basic reproduction number from the detected change point locations #
sir_result <- fit_SIR(ret_opt = cp_result)

The cp_result is a list containing the change point detection results. It has three elements:

  1. The best change point detection result determined by the leave-one-out cross validation
  2. The input data setting for the corresponding state
  3. Setting information for the model fitting

The sir_result contains the output from the SIR model. It contains:

  1. R_0 estimation for each segment and the corresponding 95% credible interval
  2. A 7-day prediction for the daily new COVID-19 cases based on the data in the last segment
  3. A 7-day prediction for the daily new COVID-19 cases based on the full data

Visualize the model output

Here, we demonstrate our model results using the COVID-19 data for Texas. Our results include two major components: (1) change point detection, and (2) estimation of the basic reproduction number R_0.

# 1. Daily actively infectious counts and date information #
# Reads the change point result (`cp_result`) and prepares everything the plot
# needs: x positions, date labels, the y values and the y-axis range.
cp_res <- cp_result[[1]]
n_record <- nrow(cp_result$data)
ref_time <- cp_result$data$date
start_day <- ref_time[1]
end_day <- ref_time[n_record]
state.nam <- unique(cp_result$data$state)
population <- cp_res$population

# Month-day label for every record; one candidate tick position per record.
# (Every date is trivially a member of its own vector, so the positions are
# simply 1..n_record -- no per-date membership test is needed.)
time_x_label_red <- format(ref_time, format = "%m-%d")
ticks_x_pos <- seq_len(n_record)

# `FALSE` is spelled out: `F` is an ordinary variable that can be reassigned.
plot_df <- data.frame(count = as.numeric(cp_result$data$I_obs),
                      time = ref_time,
                      stringsAsFactors = FALSE)
plot_df$idx <- seq_len(nrow(plot_df))

# y values: log of (count + 1) / population; the +1 keeps zero counts finite.
plot_yval <- log((plot_df$count + 1) / population)

# The y-axis range must cover both the log(alpha) estimates and the observed values.
loga_df <- cp_res$loga_est
y_range <- range(loga_df, plot_yval[!is.na(plot_yval)])

# Pad the range by a factor of 1.2. A negative upper bound is divided instead of
# multiplied so that it moves toward zero (i.e. upward).
# NOTE(review): multiplying the lower bound by 1.2 only widens the range when
# that bound is negative -- presumably always true for log proportions; confirm.
if (max(y_range) > 0) {
  y_range <- c(min(y_range) * 1.2, max(y_range) * 1.2)
} else {
  y_range <- c(min(y_range) * 1.2, max(y_range) / 1.2)
}

# 2. change point location (with confidence interval) information #
# One row per record: the change point indicator, the date, and the log(alpha)
# summary columns (`log_alpha`, `2.5%`, `97.5%`) taken from `loga_df`.
gamma_ppm <- cp_res$gamma_ppm
cp_df <- data.frame(gamma = gamma_ppm,
                    time = cp_result$data$date,
                    lga_mean = loga_df$log_alpha,
                    lga_lower = loga_df$`2.5%`,
                    lga_upper = loga_df$`97.5%`,
                    # `FALSE` spelled out: `F` is a reassignable variable
                    stringsAsFactors = FALSE)
# Numeric indicator vector; entries equal to 1 are used as change point positions
gamma_plot <- as.numeric(cp_df[, 1])

# 3. R_0 estimation for each segment
SIR_res <- sir_result[[1]]$full_sir

# Running sum of the indicator gives every record the label of its segment
z_vec <- cumsum(gamma_plot)
cp_seg <- sum(gamma_plot)

# Segment boundaries on the x-axis: every change point position plus the last record
seg_plot <- c(which(gamma_plot == 1), n_record)

# Number of records in each segment
seg_size <- as.numeric(table(z_vec))
seg_len <- seg_size
seg_len_rep <- rep(seg_len, times = seg_len)

# Expand the per-segment value in the first column of SIR_res to one value per record
SIR_df <- data.frame(z = z_vec,
                     R0 = rep(SIR_res[, 1], times = seg_len))

# 4. generate the plot #
# A right margin of 5 lines leaves room for an axis and label on side 4
par(mar = c(5, 5, 3, 5))
plot(x = plot_df$idx, y = plot_yval,
     type = "b", pch = 16, ylim = y_range,
     xaxt = "n", xlab = "",
     ylab = "log(actively infectious cases / population)",
     main = state.nam)

# Custom x-axis: label every 5th record with its month-day string, drawn vertically
tick_sel <- seq(1, length(ticks_x_pos), by = 5)
axis(side = 1, at = ticks_x_pos[tick_sel],
     labels = time_x_label_red[tick_sel], las = 2)

# (1) Confidence intervals for each change point #
# Shade, over the full height of the plot, the interval stored in `cp.ci` for
# each change point. The first position with gamma_plot == 1 is dropped here.
gamma.idx <- which(gamma_plot == 1)[-1]
gamma.ci.list <- cp_result[[1]]$change_point_details
# seq_along() instead of 1:length(): when no change point remains, 1:0 would
# iterate over c(1, 0) and index a non-existent entry; now nothing is drawn.
# NOTE(review): assumes gamma.ci.list[[ii]] belongs to gamma.idx[ii] -- confirm.
for (ii in seq_along(gamma.idx)) {
  cp.detail <- gamma.ci.list[[ii]]
  rect(xleft = min(cp.detail$cp.ci), ybottom = min(y_range),
       xright = max(cp.detail$cp.ci), ytop = max(y_range),
       border = NA,
       col = rgb(0.5, 0.5, 0.5, 1/4))  # semi-transparent gray
}

# (2) R_0 estimation for each segment
# Vertical line at every change point position
abline(v = which(gamma_plot == 1), col = "gray30", lwd = 2)

# Horizontal centre of each segment (midpoint of its two boundaries).
# seq_len() instead of 1:cp_seg so that zero segments yields an empty loop.
seg_ids <- seq_len(cp_seg)
seg_mid <- (seg_plot[seg_ids] + seg_plot[seg_ids + 1]) / 2

# One entry per row of labels: which column of SIR_res to print, the vertical
# position as a fraction of the lower y limit, and the text size.
# Column 1 is the R_0 estimate; columns 2 and 3 are presumably the bounds of its
# interval (printed smaller, below and above the estimate) -- TODO confirm.
label_spec <- list(
  list(column = 1, y_frac = 0.85, cex = 1),
  list(column = 2, y_frac = 0.90, cex = 0.75),
  list(column = 3, y_frac = 0.80, cex = 0.75)
)

for (spec in label_spec) {
  for (i in seg_ids) {
    # Each value is formatted on its own so labels are not padded to a common width
    text(seg_mid[i], min(y_range) * spec$y_frac,
         format(round(SIR_res[i, spec$column], 2), nsmall = 2),
         cex = spec$cex, col = "red")
  }
}

# (3) Trend of change in R_0 #
# Draw the next plot on top of the existing one instead of starting a new page
par(new = TRUE)
plot(x = plot_df$idx, y = SIR_df$R0,
     type = "l", lty = 2, lwd = 3, col = "red",
     xaxt = "n", yaxt = "n",
     xlab = "", ylab = " ")
# The overlaid series gets its own scale and label on the right-hand side
axis(side = 4)
mtext(text = "R0 estimation", side = 4, line = 3)