Automatic Site Selection for Acute Respiratory Distress Syndrome (ARDS)



ARDS Clinical Trials

Comprehensive analysis and visualization of ARDS (Acute Respiratory Distress Syndrome) clinical trial locations across the USA. For the final version of the repository, head to [https://github.com/SUMEETRM/GEn1E](https://github.com/SUMEETRM/GEn1E).

ARDS (Acute Respiratory Distress Syndrome) is a lung condition that leads to fluid accumulating in the alveoli, thereby depriving organs of oxygen. It often occurs in critically ill individuals or those with injuries and is frequently fatal. To date, no analysis has been conducted on the state-level or county-level prevalence of ARDS in the US, making the process of determining clinical trial locations subjective and prone to failure. This paper presents a standardized method, grounded in machine learning, to ascertain the prevalence of ARDS using multiple cause factors. It introduces a platform to identify clinical trial locations across the US in a manner that leverages almost all publicly available data. Furthermore, the paper explains this entire methodology in a way that can be readily applied to other diseases, enabling researchers to locate the statistically optimal locations for clinical trials.


📁 Data Sources

The ards_data folder in this repository hosts several key datasets that are instrumental for our analysis on Acute Respiratory Distress Syndrome (ARDS) and associated factors.

Here is a breakdown of the datasets included:

  • 🏥 ARDS_centers: Contains names and geographic coordinates of hospitals known for their specialization in ARDS treatment. Data is originally sourced from ClinicalTrials.gov.

  • 📈 County_COPD_prevalence: A dataset from the CDC that provides information about the prevalence of Chronic Obstructive Pulmonary Disease (COPD) at a county level.

  • 🏊 Drowning_data: A dataset providing insights into drowning incidents, sourced from the CDC.

  • 🤧 flu_data: This dataset contains data on flu cases, as provided by the CDC.

  • 🤒 Pnem: Includes data related to pneumonia cases, also sourced from the CDC.

  • 😷 Covid_nyt: Contains detailed COVID-19 case data, sourced from the New York Times.

  • 🚑 Sepsis1: CDC-sourced dataset containing detailed case data on sepsis.

  • 💉 Vaccination: Includes data related to flu vaccination rates.

  • 🗺️ uscounties: A comprehensive dataset containing details about US county coordinates, population figures, and FIPS codes.

  • 🚬 Tobacco: A dataset that provides insights into tobacco usage rates.

  • 🇺🇸 States.csv: A dataset containing the names, abbreviations, and central coordinates of US states.

Refer to the original data sources for more comprehensive details and context about these datasets.

Folium

The project leverages Folium, a Python library for creating interactive maps, to visualize the geographical distribution of clinical trial locations. The interactive maps provide a comprehensive overview of the trial sites and their distribution across different regions.

📂 Random Forest for Feature Importance

This directory houses a Python script that uses a Random Forest regressor to estimate the importance of the various risk factors in our Acute Respiratory Distress Syndrome (ARDS) research.

We transitioned from a previous constrained optimization approach (SLSQP) to random forests to achieve higher R-squared values, reflecting an improved fit of the model to the data.

🔽 Here is a brief overview of the key steps in the script:

  1. Import required libraries (Pandas, Numpy, Sklearn)
  2. Load the dataset 'state_data_1.csv'
  3. Normalize the risk factors and mortality rates using StandardScaler from Sklearn
  4. Modify the 'normalized_vaccination' column since higher vaccination rates imply lower rates of ARDS according to studies
  5. Define the features (risk factors) and target (mortality rates)
  6. Initialize and fit the Random Forest Regressor with the defined features and target
  7. Run predictions using the trained model and the features
  8. Calculate and print the R-squared score as a measure of the model's fit to the data
  9. Extract and return the feature importances as determined by the Random Forest model, which are the optimal weights to be used for the various risk factors in further analysis.

Our shift to random forests has enhanced our model's performance, and the feature importances computed by the Random Forest Regressor offer more robust insights into which risk factors contribute most to ARDS mortality. These weights, or importances, are then used to generate a heatmap for visual interpretation of our results.
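The following is a minimal, illustrative sketch of that pipeline; the column names and target name are assumptions and may differ from those in state_data_1.csv.

```python
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

# Illustrative column names; adjust to match state_data_1.csv.
risk_factors = ["normalized_smoking", "normalized_copd", "normalized_covid",
                "normalized_drowning", "normalized_sepsis", "normalized_flu",
                "normalized_pneumonia", "normalized_vaccination"]

df = pd.read_csv("state_data_1.csv")

# Normalize the risk factors and the mortality target.
df[risk_factors] = StandardScaler().fit_transform(df[risk_factors])
df["mortality"] = StandardScaler().fit_transform(df[["mortality"]]).ravel()

# Higher vaccination rates imply lower ARDS rates, so flip that column's sign.
df["normalized_vaccination"] = -df["normalized_vaccination"]

X, y = df[risk_factors], df["mortality"]

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X, y)

print("R-squared:", r2_score(y, model.predict(X)))
print(dict(zip(risk_factors, model.feature_importances_)))  # the feature weights
```

The importances printed at the end correspond to step 9 above and are the weights carried into the heatmap generation.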

📂 Repository Structure

This repository is structured in a way that makes the analysis flow intuitive and easy to follow. Here's a quick tour:

  • 📓 ards_map.ipynb: This is the primary Jupyter notebook where all the code runs. It integrates various components of the project and generates a detailed heatmap in HTML format. The code in this notebook is modular, making it easy to reuse and adapt to your needs.

  • ⚖️ weight_optimization.py: This Python script contains an optimization algorithm that calculates the best-suited weights for generating the weighted Folium heatmap. The weights are derived from a supervised learning algorithm optimized against an ARDS mortality study.

  • 🌐 usa_map.html: This is a ready-to-use HTML map showcasing 500 clinical trial locations in a heatmap format. It's a direct product of running the ards_map.ipynb notebook.

  • ⚙️ dataloader.py: This Python script is responsible for loading all the necessary data from the ards_data folder.

  • 📁 ards_data: This folder is the data repository hosting several datasets used in the analysis. For a more detailed description of the individual datasets, please refer to the Data Sources section.

  • 📑 ards_state_vals: This file contains state-wise ARDS mortality rates, as derived from a linked study.

  • 📄 ARDS_locations.csv: This CSV file offers comprehensive data on various studies, including details such as the number of participants, participant age and sex, study dates, and updates on study completion.

Explore each of these components to get a better understanding of the project.

Methods

Identification of cause and risk factors

To address the key challenges and develop a standardized framework, we used the following methodology. Cause and risk factors for ARDS were identified from peer-reviewed medical journals and reputable government agencies, yielding the following list of eight cause and risk factors:

  • Smoking
  • Pneumonia
  • Sepsis
  • COVID-19
  • Chronic lung diseases such as COPD that share risk factors with ARDS
  • Drowning
  • Influenza
  • Low vaccination rates

However, a significant issue arises as these factors are neither ranked nor weighted, and some even originate from different sources. Therefore, following data collection, the methodology establishes a standardized and statistically supported approach to integrating this information.

A list of reputable academic journals is also given in the Supplementary Materials.

Data Collection

Data for these cause and risk factors was collected from sources such as the Centers for Disease Control and Prevention and reputable peer-reviewed studies. County-level data would ideally strengthen the analysis, but such granularity is often available only at the state level, or not at all. Therefore, statistics from government datasets and national-level studies were located and used independently for each data source, as outlined in the flowchart below.

Data Processing

Data was collected from the sources mentioned above, cleaned, and converted to CSV files for compatibility with Python's Pandas library. All states and counties in the U.S. were included. When only state-level data was available for a factor, it was generalized to all counties within that state. Anomalies were reviewed during cleaning but were not removed unnecessarily; the only data dropped were for areas with no data at all, ensuring consistency. The data was then normalized using MinMax normalization, mathematically represented as: Normalized Value = (Value - Min) / (Max - Min). Finally, the data was merged on both state and county names, an essential step because some counties in different states share the same name.
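As an illustration of this step (file and column names below are assumptions, not the repository's exact schema), the normalization and state/county merge could look like this:

```python
import pandas as pd

def minmax(series: pd.Series) -> pd.Series:
    """Normalized Value = (Value - Min) / (Max - Min)."""
    return (series - series.min()) / (series.max() - series.min())

# Assumed file and column names for illustration.
counties = pd.read_csv("ards_data/uscounties.csv")            # state, county, lat, lng, population, fips
copd = pd.read_csv("ards_data/County_COPD_prevalence.csv")    # state, county, copd_rate
tobacco = pd.read_csv("ards_data/Tobacco.csv")                # state, smoking_rate (state level only)

# Merge on both state and county, since county names repeat across states;
# state-level data is generalized to every county within that state.
merged = counties.merge(copd, on=["state", "county"], how="left")
merged = merged.merge(tobacco, on="state", how="left")

# Drop only counties with no data at all, then normalize each factor to [0, 1].
merged = merged.dropna(subset=["copd_rate", "smoking_rate"], how="all")
for col in ["copd_rate", "smoking_rate"]:
    merged[f"normalized_{col}"] = minmax(merged[col])
```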

Heatmap Development

A method of generating a weighted heatmap was implemented. While libraries such as Folium, GeoPandas, and ArcPy exist, they do not readily merge and overlay several heatmaps on top of each other, regardless of the weights: at county-level granularity, merging heatmaps introduces visual artifacts and makes the platform less interpretable.

A method was therefore developed to pre-compute final outputs. This process involves assigning weights to each feature and using these pre-computed values to calculate the final values for the heatmap on a per-county basis. Subsequently, the Folium and MapBox libraries in Python were used to create heatmaps with these combined values at a high level of granularity. This approach allows for significant detail when zooming into the map, while still maintaining color consistency and displaying more information based on the zoom setting.
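A hedged sketch of the pre-computation and plotting step, assuming a `merged` county-level DataFrame like the one above and Folium's `HeatMap` plugin; the repository's actual radii, tiles, and MapBox layers may differ.

```python
import folium
from folium.plugins import HeatMap

# Assumed learned weights per normalized factor column (see Final Feature Weights).
weights = {"normalized_copd_rate": 0.10, "normalized_smoking_rate": 0.09}

# Pre-compute one combined intensity per county instead of overlaying
# a separate heatmap per factor.
merged["score"] = sum(w * merged[col] for col, w in weights.items())

m = folium.Map(location=[39.8, -98.6], zoom_start=4, tiles="cartodbpositron")
HeatMap(merged[["lat", "lng", "score"]].values.tolist(),
        radius=12, blur=18).add_to(m)
m.save("usa_map.html")
```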

Maintaining consistency during pre-computation

A toggle allows users to select which data sources the heatmap is built from; the weights of unselected features are set to 0. The interface was designed to be user friendly, with interpretability as a key priority, so users can view individual features one at a time or layer them to see how their combined prevalence relates to ARDS.
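A minimal sketch of the weight-zeroing step behind the toggle (function and feature names are illustrative):

```python
def effective_weights(all_weights: dict, selected: set) -> dict:
    """Zero out the weight of any feature the user did not toggle on."""
    return {f: (w if f in selected else 0.0) for f, w in all_weights.items()}

# Example: only sepsis and pneumonia toggled on.
weights = effective_weights(
    {"sepsis": 0.19, "pneumonia": 0.19, "flu": 0.03, "covid": 0.17},
    selected={"sepsis", "pneumonia"},
)
```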

Weights

The assignment of weights is where a key innovation of this paper lies: there is no established methodology for assigning importance to each feature. For a condition like ARDS, data is available only at the state level, and it pertains to mortality rather than incidence; it may also be affected by reporting inaccuracies. This is why considering cause and risk factors is important: ARDS is diagnosed based on these factors when patients meet the Berlin Criteria.

Approach and problem

An issue with simple optimization is the lack of data available at a granular level. This paper therefore followed the approach of optimizing these weights at a state level and translating these values to counties within their respective states, thus assigning importance to features at a county level to predict the prevalence of ARDS on both small and large scales. This approach was also validated by assessing weight ratios directly on county level data, and these tables are available in Supplementary Materials.

Supervised Machine Learning

Therefore, instead of eyeballing the weights, this paper developed a statistically sound supervised machine learning method using XGBoost, a gradient boosting technique. The eight normalized factors served as the predictors and state-level ARDS mortality rates as the response, with the model fit on 50 rows, one per US state. Gradient boosting was then used to regress over these variables and assign an importance to each feature. Consequently, each feature received a weight representing its relative importance in predicting the response variable, as determined by the gradient boosting algorithm.

Process

  1. MinMaxScaler: We used this as a preprocessing step to normalize the data. It scales and translates the features to a 0 to 1 range. Mathematically, for each feature $X_j$, the normalized feature $X_j'$ is calculated as: $X_j' = \frac{{X_j - \min(X_j)}}{{\max(X_j) - \min(X_j)}}$ where $\min(X_j)$ and $\max(X_j)$ are the minimum and maximum of $X_j$, respectively. This operation transforms the features to fall within the range $[0, 1]$.

  2. XGBoost Gradient Boosting: XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible, and portable. It operates by constructing a sequence of trees, where each successive tree is built for the prediction residuals of the preceding tree, to ultimately output the sum of the predictions from individual trees for regression problems. The prediction from an XGBoost model is: $F(x) = \sum_{k=1}^{K} f_k(x)$, where $f_k(x)$ are the individual weak learners (decision trees), and $K$ is the total number of trees.

  3. Feature Importances: The feature importances of an XGBoost model are calculated based on the improvement in accuracy brought by a feature to the branches it's on. The more improvement it brings, the more important the feature is. It is computed as: $FI_j = \sum_{t=1}^{T} \frac{G^2}{H + \lambda}$, where $G$ is the gradient statistic, $H$ is the hessian statistic, and $\lambda$ is the regularization term. This allows the feature importances to be used to weigh all the factors that we have in the optimization process.

  4. R-Squared Score: The R-squared score represents the proportion of the variance for a dependent variable that's explained by independent variables in a regression model. It is calculated as $R^2 = 1 - \frac{\sum (y_{\text{true}} - y_{\text{pred}})^2}{\sum (y_{\text{true}} - y_{\text{mean}})^2}$, where $y_{\text{true}}$ are the true target values, $y_{\text{pred}}$ are the predicted target values by the model, and $y_{\text{mean}}$ is the mean of $y_{\text{true}}$.
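A compact, illustrative sketch of steps 1 to 4, assuming the `xgboost` and `scikit-learn` packages; the file and column names are assumptions.

```python
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score
from xgboost import XGBRegressor

features = ["normalized_smoking", "normalized_copd", "normalized_covid",
            "normalized_drowning", "normalized_sepsis", "normalized_flu",
            "normalized_pneumonia", "normalized_vaccination"]

df = pd.read_csv("state_data_1.csv")  # assumed: 50 rows, one per state

# Step 1: scale features and target to [0, 1].
df[features] = MinMaxScaler().fit_transform(df[features])
df["mortality"] = MinMaxScaler().fit_transform(df[["mortality"]]).ravel()

# Reverse the vaccination effect: lower vaccination rates imply higher risk.
df["normalized_vaccination"] = 1 - df["normalized_vaccination"]

X, y = df[features], df["mortality"]

# Step 2: fit the gradient-boosted tree ensemble.
model = XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=5)
model.fit(X, y)

# Steps 3 and 4: feature importances (used as weights) and goodness of fit.
weights = dict(zip(features, model.feature_importances_))
print("R2:", r2_score(y, model.predict(X)))
```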

Design choice for using XGBoost

  • Non-Linearity: XGBoost creates a non-linear decision boundary by constructing an ensemble of decision trees. Each decision tree, identified by $k$, maps an input instance $x$ to an output prediction $f_k(x)$. The $k$-th decision tree is a piecewise constant function, and the final prediction is the sum of these: $F(x) = \sum_{k=1}^{K} f_k(x)$, which allows for non-linear decision boundaries.

  • Handling High Dimensionality: XGBoost can efficiently manage high-dimensional data through its column block and parallelizable architecture. Given a $d$-dimensional input vector $x$ in $\mathbb{R}^d$, XGBoost can process each dimension in parallel during the construction of each tree, significantly reducing computational cost. This is also of vital importance since the data we're working with is very high dimensional compared to the actual rows of data we have.

  • Regularization: XGBoost extends the traditional gradient boosting framework with a regularization term that penalizes complex models, helping to prevent overfitting. The objective function in XGBoost is given by: $\text{Obj}(\Theta) = \sum_{i=1}^{n} l(y_i, \hat{y}_i) + \sum_{k=1}^{K} \Omega(f_k)$, where $l$ is the loss function, $y_i$ are the true labels, $\hat{y}_i$ are the predicted labels, and $\Omega(f_k)$ is the regularization term, which can include L1 ($\|w\|_1$) or L2 ($\|w\|_2^2$) penalties.

  • Handling Sparse Data: XGBoost uses a sparsity-aware algorithm to find splits on features, which can efficiently handle sparse data. This works well for our use case since we only have a total number of rows equal to the number of states in the US.

  • Computational Efficiency: XGBoost uses techniques for parallel learning and cache-aware block structure for out-of-core computing to enhance its computational efficiency. This makes it computationally efficient, which is essential when dealing with high dimensional data.

Other methods that were analyzed

  • Decision Tree Algorithms: Random Forests were used in our experimentation. These are meta-estimators that fit a number of decision trees on subsets of the dataset and average their predictions to improve accuracy.
  • Boosting Algorithms: Gradient Boosting builds successive models that predict the residual errors of prior models and sums their predictions to produce the final output.
  • Support Vector Machine Regression (SVMR): SVMs fit the hyperplane that best represents the data within a margin of tolerance around it.
  • Neural Networks: A connected set of nodes whose weighted activations are propagated forward, with backpropagation used to optimize the weights.
  • Regression Algorithms: These find a linear relationship between a dependent variable and one or more independent variables.
  • Constrained Optimization Algorithms: Sequential Least Squares Programming (SLSQP) was used to minimize the objective function, i.e., the squared difference between the weighted sum of features and state-level ARDS mortality rates. Quadratic programming handles quadratic objective functions with linear constraints, which fits our data well.

| Method | R2 fit |
| --- | --- |
| XGBoost | 0.999 |
| Neural Networks | -0.468 |
| SLSQP | 0.104 |
| Random Forests | 0.857 |
| Linear Regression | -0.277 |
| SVMR | -0.306 |
| Uniform | 0.090 |

Model Design

Post normalization, the ‘vaccination’ column is subtracted from 1. This is to reverse its effect since lower vaccination rates are correlated with ARDS prevalence.

The normalization procedure is crucial in ensuring that the range of feature values does not bias or unduly influence the model's learning process, especially in algorithms like gradient boosting, which may be sensitive to the range of input features. Moreover, normalization is applied to both the predictors and the target, ensuring consistency across the board.

The XGBoost model is trained to fit a set of predictors, denoted as $X = [X_1, X_2, \ldots, X_8]$ and corresponding to ['normalized_smoking', 'normalized_copd', 'normalized_covid', 'normalized_drowning', 'normalized_sepsis', 'normalized_flu', 'normalized_pneumonia', 'normalized_vaccination'] respectively, to a response variable $Y$, 'state_level_ARDS_mortality'.

It is an ensemble of 100 decision trees, each grown by recursively partitioning the sample space to minimize the objective function until a stopping criterion is met. The final XGBoost model is an aggregation of these 100 trees, given by $f_{\text{XGB}}(X) = \sum_{k=1}^{K} f_k(X)$, where $f_k(X)$ is the prediction of the $k$-th tree.

The variable importances are computed from the total reduction in loss achieved by each predictor, averaged over all $K$ trees. If $L(X_i)$ denotes the loss function, the importance of a variable $X_j$ is given by $\mathrm{Imp}(X_j) = \frac{1}{K} \sum \left( L(X_i) - L(X_i \mid X_j) \right)$, where $L(X_i \mid X_j)$ denotes the loss for $X_i$ after splitting on $X_j$.

Finally, the coefficient of determination $R^2$, which measures the proportion of the variance in the dependent variable that is predictable from the independent variables, is computed. In our case, $R^2 = 1 - \frac{\sum_i (Y_i - f_{\text{XGB}}(X_i))^2}{\sum_i (Y_i - \text{mean}(Y))^2}$, where $\text{mean}(Y)$ is the mean value of the target variable. It essentially gives the overall goodness-of-fit of the model.
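As a sanity check (assuming the fitted `model`, `X`, and `y` from the sketch above), the manual formula agrees with scikit-learn's `r2_score`:

```python
import numpy as np
from sklearn.metrics import r2_score

y_pred = model.predict(X)
ss_res = np.sum((y - y_pred) ** 2)          # residual sum of squares
ss_tot = np.sum((y - np.mean(y)) ** 2)      # total sum of squares
r2_manual = 1 - ss_res / ss_tot

assert np.isclose(r2_manual, r2_score(y, y_pred))
```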

XGBoost Evaluation and Hyperparameter Tuning

| n_estimators | learning_rate | max_depth | R2 |
| --- | --- | --- | --- |
| 150 | 0.1 | 5 | 1.000 |
| 100 | 0.1 | 5 | 1.000 |
| 150 | 0.1 | 3 | 0.999 |
| 100 | 0.1 | 3 | 0.999 |
| 150 | 0.01 | 5 | 0.915 |
| 100 | 0.01 | 5 | 0.806 |
| 150 | 0.01 | 3 | 0.793 |
| 100 | 0.01 | 3 | 0.659 |
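A hedged sketch of how such a sweep could be reproduced, reusing the `X` and `y` defined in the earlier XGBoost sketch:

```python
from itertools import product

from sklearn.metrics import r2_score
from xgboost import XGBRegressor

results = []
for n_estimators, learning_rate, max_depth in product([100, 150], [0.01, 0.1], [3, 5]):
    model = XGBRegressor(n_estimators=n_estimators,
                         learning_rate=learning_rate,
                         max_depth=max_depth)
    model.fit(X, y)
    results.append((n_estimators, learning_rate, max_depth,
                    r2_score(y, model.predict(X))))

# Highest R2 first.
for row in sorted(results, key=lambda r: -r[3]):
    print(row)
```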

Final Feature Weights

| Feature | Weight |
| --- | --- |
| Smoking | 0.09 |
| COPD | 0.10 |
| Covid | 0.17 |
| Drowning | 0.08 |
| Sepsis | 0.19 |
| Flu | 0.03 |
| Pneumonia | 0.19 |
| Vaccination | 0.13 |

Identifying Locations

The locations of hospitals and institutions were identified from past US clinical trials using the database available on ClinicalTrials.gov. This source lists more than 400 locations where trials for ARDS have been conducted in the past, including information about the number of patients in those studies, the type of study, among other details. The data provided are diverse, encompassing factors such as the study phase, date, demographics, sponsors, etc., offering potential avenues for future research extensions.

The locations are presented as marker clusters, allowing deeper exploration: clicking on a region reveals more granular data. To place these locations as markers on a map, their coordinates (latitudes and longitudes) were required, and the GeoPy tool was used to derive them. An algorithm was devised to sift through hospital names and their respective states to pinpoint accurate coordinates for every location. Details regarding this are provided in the supplementary information section.
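A minimal sketch of that geocoding step using GeoPy's Nominatim geocoder; the repository's exact algorithm, fallback logic, and column names may differ.

```python
import pandas as pd
from geopy.geocoders import Nominatim
from geopy.extra.rate_limiter import RateLimiter

locations = pd.read_csv("ARDS_locations.csv")  # assumed columns: hospital, state

geolocator = Nominatim(user_agent="ards_site_selection")
geocode = RateLimiter(geolocator.geocode, min_delay_seconds=1)  # respect rate limits

def to_coords(row):
    # Combine hospital name and state to disambiguate similarly named sites.
    hit = geocode(f"{row['hospital']}, {row['state']}, USA")
    return pd.Series({"lat": hit.latitude if hit else None,
                      "lng": hit.longitude if hit else None})

locations[["lat", "lng"]] = locations.apply(to_coords, axis=1)
```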

In addition, the data directly available on ClinicalTrials.gov, which includes information about the Principal Investigator and local contact, was integrated via a scraping algorithm that complies with the website's settings. The success of clinical trials often hinges on the number of available ICU beds in the region where the trial is being conducted, as this can indicate the number of patients a region can serve and the potential for patient recruitment. To account for this, a dataset indicating the number of ICU beds per county in the US was sourced and amalgamated with the main database. The methodology and database have been detailed in the supplementary information section and can be applied to other data sources, such as the number of ventilators, X-ray machines, and the like.
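A short sketch of folding the ICU-bed counts into the site table; the file and column names are assumptions.

```python
import pandas as pd

icu = pd.read_csv("icu_beds_by_county.csv")   # assumed: state, county, icu_beds
sites = pd.read_csv("ARDS_locations.csv")     # assumed: hospital, state, county, ...

# Merge on both state and county, since county names repeat across states.
sites = sites.merge(icu, on=["state", "county"], how="left")
```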

Web GUI and Server Process

For the graphical user interface (GUI), the platform was designed with an emphasis on user-friendliness to streamline the process of site identification. A toggle box was included on the side, featuring options to generate heatmaps for different features and to obtain a list of the top-ranked locations in the US. Initially, the map appears empty, offering the user the option to add cause factors, locations, and rankings. The transition of visual appearance is illustrated in the figure below. The locations are regionally clustered, and the map is interactive, permitting clicking and zooming, with animations facilitating navigation through the different locations. Upon clicking a marker cluster, a zoom-in view of all hospitals within that region is presented. Clicking on a pin then reveals an overall score, used to rank the hospitals, and data such as hospital name, location, total successful studies, number of ICU beds, and details on the most recent study at that location. The Principal Investigator’s name is also displayed as part of the clickable pin. This amalgamation of all potentially available public data, about both the disease and the sites, can expedite the entire process considerably.
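A minimal sketch of the clustered, clickable markers using Folium's `MarkerCluster` plugin, assuming a `sites` DataFrame with the columns shown; the live platform includes additional fields such as the Principal Investigator and recent study details.

```python
import folium
from folium.plugins import MarkerCluster

# `sites` is assumed to hold lat, lng, hospital, score, and icu_beds columns.
m = folium.Map(location=[39.8, -98.6], zoom_start=4)
cluster = MarkerCluster().add_to(m)

for _, row in sites.iterrows():
    popup = folium.Popup(
        f"<b>{row['hospital']}</b><br>"
        f"Score: {row['score']}<br>"
        f"ICU beds: {row['icu_beds']}",
        max_width=300,
    )
    folium.Marker([row["lat"], row["lng"]], popup=popup).add_to(cluster)

m.save("usa_map.html")
```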

For the platform's functioning on a web server, a Flask server was set up in Python. The frontend was kept lightweight, relying on HTML, JavaScript, and CSS. Toggle inputs were sent to the Flask backend, which processed them, ran the optimization algorithm on the selected features, regenerated the map and toggles, and relayed the updated frontend back to the browser.
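A skeletal sketch of such a Flask round trip; `effective_weights`, `ALL_WEIGHTS`, and `build_heatmap` are hypothetical helpers standing in for the repository's own functions, and the form field name is illustrative.

```python
from flask import Flask, render_template, request

app = Flask(__name__)

@app.route("/", methods=["GET", "POST"])
def index():
    # Toggle names submitted from the HTML form (illustrative field name).
    selected = set(request.form.getlist("factors")) if request.method == "POST" else set()

    # Zero out unselected weights, rebuild the heatmap, and save it where
    # the template can embed it (e.g. in an <iframe>).
    weights = effective_weights(ALL_WEIGHTS, selected)   # hypothetical helper
    build_heatmap(weights).save("static/usa_map.html")   # hypothetical helper

    return render_template("index.html", selected=selected)

if __name__ == "__main__":
    app.run(debug=True)
```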

Presentation

For more detailed information about the project, data sources, methodologies, and findings, please refer to the project documentation.