MachineLearning_NetworkLevelAnalysis

A library of machine learning pipelines with K-fold cross-validation for functional connectivity (FC) data

Primary language: MATLAB

Machine Learning Pipelines for FC Data

This repository provides a MATLAB library of machine learning (ML) pipelines for functional connectivity (FC) data. It is built around a predictive linear support vector regression (LSVR) model, evaluated with a nested K-fold cross-validation scheme designed to account for related subjects such as twins and siblings. The biological interpretation of the ML results can then be investigated with network-level enrichment analysis.
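For reference, a standard form of the linear SVR objective with a ridge penalty is sketched below, with n training subjects, FC feature vectors x_i, labels y_i (e.g., age), weights w, intercept b, insensitivity margin ε, and penalty strength λ. This textbook formulation is assumed for illustration; the exact loss scaling and solver used by the library may differ.

```latex
\min_{\mathbf{w},\,b}\;
\frac{1}{n}\sum_{i=1}^{n}\max\!\bigl(0,\;\lvert y_i-\mathbf{x}_i^{\top}\mathbf{w}-b\rvert-\varepsilon\bigr)
\;+\;\frac{\lambda}{2}\,\lVert\mathbf{w}\rVert_2^{2}
```

A larger λ shrinks the weights more strongly, which trades some training fit for less overfitting.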

Example

  • MLCrossVal_testScript.m: An example on Human Connectome Project (HCP) data. The dataset contains 965 subjects from 420 family groups, which are indicated by the groupIDs. The goal is to predict age from the functional connectivity data. We adopted the Gordon parcellation of 333 regions of interest (ROIs), organized into 13 networks.
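The number of FC features grows quadratically with the number of ROIs. The sketch below shows one common way to turn symmetric 333 × 333 FC matrices into a subjects × features matrix, assuming each unique ROI pair (strict upper triangle, diagonal excluded) becomes one feature. Random matrices stand in for real HCP data, and the variable names are illustrative, not the library's.

```matlab
% Hypothetical sizes: 333 Gordon ROIs and a handful of synthetic subjects
% (the HCP example uses 965 subjects; random data stand in for real FC).
nROIs = 333;
nSubjects = 5;

% Logical mask selecting the strict upper triangle (diagonal excluded),
% so each unique ROI pair appears exactly once.
upperMask = triu(true(nROIs), 1);

X = zeros(nSubjects, nnz(upperMask));    % subjects x features
for s = 1:nSubjects
    A = randn(nROIs);                    % stand-in for a real FC matrix
    fc = (A + A') / 2;                   % symmetrize, as FC matrices are symmetric
    X(s, :) = fc(upperMask)';            % vectorize the unique ROI pairs
end

fprintf('%d ROIs -> %d features per subject\n', nROIs, size(X, 2));
```

With 333 ROIs this yields 333 × 332 / 2 = 55,278 features per subject, far more than the number of subjects in a typical study.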

Folders/Scripts

  • MLCrossVal.m: The main entry point of the library. Try it on your own data!
  • +featurefilter/: Code for the optional feature filter applied before fitting the LSVR prediction model. FC data typically pose a high-dimensional statistical problem: 333 ROIs give 333 × 332 / 2 = 55,278 unique ROI pairs (55,611 if the diagonal is counted), far more than the number of subjects, which is typically in the hundreds. An additional feature selection step is therefore frequently adopted to reduce overfitting. We include the popular marginal Pearson correlation filter in HighestCorr.m, which selects the N connectivity features with the highest correlations with the label (e.g., age or a behavior score). To use the highest-correlation filter:
    • crossValObj = mlnla.MLCrossVal(); % cross-validation object with default settings
    • newFilter = mlnla.featurefilter.HighestCorr(123); % filter that keeps only the 123 features with the highest correlations
    • crossValObj.featureFilter = newFilter; % use the new HighestCorr filter in the cross-validation object
  • +traintestdatasplitter/: Code for splitting the data into training and test sets in the outer loop of a nested cross-validation (CV). For example, setting crossValObj.testDataFraction = 0.2 makes the test set 20% of all subjects, i.e., an 80%/20% random split. By default, MLCrossVal.m keeps subjects from the same family together so that no family is split between the training and test sets; this is implemented in MaintainGroups.m. If the data have no family structure, one can switch to random splitting with IgnoreGroups.m, which simply discards the groupIDs:
    • crossValObj = mlnla.MLCrossVal();
    • newDataSplitter = mlnla.traintestdatasplitter.IgnoreGroups(); % splitter that divides data into training and test sets while ignoring group IDs
    • crossValObj.trainTestDataSplitter = newDataSplitter; % use the new splitter in the cross-validation object
  • +tuningmodelfitter/: Code for the inner loop of the nested CV described under +traintestdatasplitter/. In each outer-loop iteration, one can run a second K-fold CV within the outer-loop training set to tune a hyperparameter of the ML model (e.g., "lambda", the ridge penalty in the LSVR model). A list of candidate values can be specified in "lambdaTestSet" in KFoldLinearModel.m to facilitate a grid search that chooses the lambda with the lowest least-squares loss.
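The marginal Pearson correlation filter described for HighestCorr.m can be sketched as follows. This is hypothetical code in the spirit of that filter, not its actual implementation; in particular, it assumes features are ranked by absolute correlation, so strongly negative correlations also count. The filter must be computed on the training set only and the same feature indices then applied to the test set, otherwise label information leaks into the evaluation.

```matlab
% Hypothetical sketch of a top-N marginal Pearson correlation filter.
% Assumption: features are ranked by |r| (the library may rank differently).
rng(1);
nTrain = 100; nFeatures = 500; N = 20;
y = randn(nTrain, 1);                      % label, e.g. age
X = randn(nTrain, nFeatures);              % training FC features
X(:, 1:5) = X(:, 1:5) + 3 * y;             % make five features informative

% Pearson r between every column of X and y, using training data only
Xc = X - mean(X, 1);                       % center features (implicit expansion)
yc = y - mean(y);
r = (Xc' * yc) ./ (sqrt(sum(Xc .^ 2, 1))' * norm(yc));

% Keep the N features with the largest absolute correlation
[~, order] = sort(abs(r), 'descend');
keepIdx = sort(order(1:N));

XTrainFiltered = X(:, keepIdx);            % reuse keepIdx on the test set
```

The informative features 1 to 5 have r ≈ 0.95 and are always retained, while the remaining slots go to noise features with small chance correlations.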
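The family-preserving split performed by default can be illustrated with a greedy sketch: whole families, in random order, are moved to the test set until it holds at least testDataFraction of all subjects. This is one possible approach under stated assumptions, not the code in MaintainGroups.m, and the groupIDs below are hypothetical.

```matlab
% Hypothetical family labels: subjects sharing a groupID are relatives.
groupIDs = [1 1 2 3 3 3 4 5 5 6 7 7 8 9 9 10]';
testDataFraction = 0.2;

families = unique(groupIDs);
families = families(randperm(numel(families)));   % shuffle family order
nTarget = round(testDataFraction * numel(groupIDs));

isTest = false(size(groupIDs));
for k = 1:numel(families)
    if nnz(isTest) >= nTarget
        break;                                     % test set is large enough
    end
    isTest(groupIDs == families(k)) = true;        % move the whole family
end
isTrain = ~isTest;

% No family appears on both sides of the split
assert(isempty(intersect(groupIDs(isTrain), groupIDs(isTest))));
```

Because families are moved as units, the test fraction is only approximately 20%. An IgnoreGroups-style random split reduces to shuffling subject indices with randperm and taking the first round(testDataFraction * n) as the test set.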
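The inner-loop grid search over lambda can be sketched as K-fold CV within the outer-loop training set, scoring each candidate by the pooled mean squared error on held-out folds. To keep the sketch runnable in base MATLAB, closed-form ridge regression is swapped in for the LSVR model; in the Statistics and Machine Learning Toolbox, fitrlinear(X, y, 'Learner', 'svm', 'Lambda', lambda) fits a linear SVR with a ridge penalty and could replace that step, though whether KFoldLinearModel.m does so is not stated here. The candidate grid and data are hypothetical, and subjects are assigned to inner folds individually rather than by family.

```matlab
% Sketch of inner-loop lambda tuning (ridge stands in for LSVR).
rng(2);
nTrain = 120; nFeatures = 30; K = 5;
X = randn(nTrain, nFeatures);                     % outer-loop training features
beta = [ones(5, 1); zeros(nFeatures - 5, 1)];
y = X * beta + 0.5 * randn(nTrain, 1);            % synthetic label

lambdaTestSet = 10 .^ (-3:3);                     % hypothetical candidate grid
foldID = mod(randperm(nTrain), K)' + 1;           % random fold labels 1..K
cvMSE = zeros(size(lambdaTestSet));

for li = 1:numel(lambdaTestSet)
    lambda = lambdaTestSet(li);
    sqErr = 0;
    for k = 1:K
        tr = foldID ~= k;
        va = foldID == k;
        % Ridge fit; intercept omitted because the synthetic data are zero-mean
        w = (X(tr, :)' * X(tr, :) + lambda * eye(nFeatures)) \ (X(tr, :)' * y(tr));
        sqErr = sqErr + sum((y(va) - X(va, :) * w) .^ 2);
    end
    cvMSE(li) = sqErr / nTrain;                   % pooled held-out MSE
end

[~, best] = min(cvMSE);
bestLambda = lambdaTestSet(best);                 % refit on all training data with this
```

Every subject is held out exactly once, so dividing the summed squared error by nTrain gives the pooled least-squares loss used to rank the candidates.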

References

  • Jiaqi Li, Ari Segel, Xinyang Feng, Jiaxin Cindy Tu, Andy Eck, Kelsey King, Babatunde Adeyemo, Nicole R. Karcher, Likai Chen, Adam T. Eggebrecht, Muriah D. Wheelock. Network level analysis provides a framework for biological interpretation of machine learning results. (2023). [major revision requested by Network Neuroscience]