torch_ecg
ECG Deep Learning Framework Implemented using PyTorch.
The system design is depicted as follows
Installation
torch_ecg requires Python 3.6+ and is available through pip:
python -m pip install torch-ecg
One can download the development version hosted on GitHub via
git clone https://github.com/DeepPSP/torch_ecg.git
cd torch_ecg
python -m pip install .
or use pip directly via
python -m pip install git+https://github.com/DeepPSP/torch_ecg.git
Main Modules
Augmenters
Augmenters are classes (subclasses of torch.nn.Module) that perform data augmentation in a uniform way and are managed by the AugmenterManager (also a subclass of torch.nn.Module). Augmenters and the manager share a common signature for the forward method:
forward(self, sig:Tensor, label:Optional[Tensor]=None, *extra_tensors:Sequence[Tensor], **kwargs:Any) -> Tuple[Tensor, ...]:
The following augmenters are implemented:
- baseline wander (adding sinusoidal and Gaussian noise)
- cutmix
- mixup
- random flip
- random masking
- random renormalize
- stretch-or-compress (scaling)
- label smooth (not actually data augmentation, but has similar behavior)
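Mixup and label smoothing are standard techniques, so their core arithmetic can be sketched independently of the library. The NumPy sketch below assumes a `(batch, n_leads, siglen)` signal, a `(batch, n_classes)` label array and a Beta(`alpha`, `alpha`) mixing coefficient; the function names and default values are illustrative and are not the torch_ecg augmenters' API or default config.

```python
import numpy as np


def label_smooth(labels: np.ndarray, smoothing: float = 0.1) -> np.ndarray:
    """Standard label smoothing: pull each one-hot/multi-hot target toward 1/K."""
    n_classes = labels.shape[-1]
    return labels * (1.0 - smoothing) + smoothing / n_classes


def mixup(sig: np.ndarray, label: np.ndarray, alpha: float = 0.5, rng=None):
    """Standard mixup along the batch dimension.

    sig: (batch, n_leads, siglen); label: (batch, n_classes).
    Each sample is blended with a randomly chosen partner from the same batch.
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)          # mixing coefficient in (0, 1)
    perm = rng.permutation(sig.shape[0])  # partner index for each sample
    mixed_sig = lam * sig + (1.0 - lam) * sig[perm]
    mixed_label = lam * label + (1.0 - lam) * label[perm]
    return mixed_sig, mixed_label
```

With `smoothing=0.1` and 4 classes, a one-hot target becomes approximately 0.925 for the true class and 0.025 for the others, so each row still sums to 1.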
Usage example (this example uses all augmenters except cutmix, each with default config):
import torch
from torch_ecg.cfg import CFG
from torch_ecg.augmenters import AugmenterManager
config = CFG(
random=False,
fs=500,
baseline_wander={},
label_smooth={},
mixup={},
random_flip={},
random_masking={},
random_renormalize={},
stretch_compress={},
)
am = AugmenterManager.from_config(config)
sig, label, mask = torch.rand(2,12,5000), torch.rand(2,26), torch.rand(2,5000,1)
sig, label, mask = am(sig, label, mask)
Augmenters can be stochastic along the batch dimension and/or the channel dimension (see the get_indices method of the Augmenter base class).
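Being stochastic along a dimension means that only a random subset of the batch items (or channels) is augmented. A minimal sketch, assuming each item is picked independently with probability `prob`; `select_indices` is a hypothetical helper and not the actual `get_indices` method.

```python
import random
from typing import List, Optional


def select_indices(prob: float, pop_size: int,
                   rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical helper: choose which batch items (or channels) to augment.

    Each of the `pop_size` items is selected independently with probability `prob`;
    the returned indices are in increasing order.
    """
    if rng is None:
        rng = random.Random()
    return [i for i in range(pop_size) if rng.random() < prob]
```

The same helper could be called with `pop_size` set to the batch size or to the number of leads, depending on which dimension is randomized.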
Preprocessors
Preprocessors acting on numpy arrays are also implemented. Similarly to the augmenters, preprocessors are managed by a manager:
import torch
from torch_ecg.cfg import CFG
from torch_ecg._preprocessors import PreprocManager
config = CFG(
random=False,
resample={"fs": 500},
bandpass={},
normalize={},
)
ppm = PreprocManager.from_config(config)
sig = torch.rand(12,80000).numpy()
sig, fs = ppm(sig, 200)
The following preprocessors are implemented:
- baseline removal (detrend)
- normalize (z-score, min-max, naïve)
- bandpass
- resample
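Z-score and min-max normalization can be sketched per lead as below, assuming a `(n_leads, siglen)` NumPy array and a small `eps` to avoid division by zero. This only illustrates the two formulas and is not the library's preprocessor code; the naïve variant is not shown.

```python
import numpy as np


def z_score(sig: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Per-lead z-score: zero mean and (almost exactly) unit std along time."""
    mean = sig.mean(axis=-1, keepdims=True)
    std = sig.std(axis=-1, keepdims=True)
    return (sig - mean) / (std + eps)


def min_max(sig: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Per-lead min-max scaling into [0, 1]."""
    lo = sig.min(axis=-1, keepdims=True)
    hi = sig.max(axis=-1, keepdims=True)
    return (sig - lo) / (hi - lo + eps)
```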
For more examples, see the README file of the preprocessors module.
Databases
This module includes classes that handle the I/O of ECG signals and labels in an ECG database and maintain its metadata (statistics, paths, plots, list of records, etc.). This module is migrated and improved from DeepPSP/database_reader.
After migration, all classes need to be tested again. The progress is as follows:
Database | Source | Tested |
---|---|---|
AFDB | PhysioNet | ✔️ |
ApneaECG | PhysioNet | ❌ |
CinC2017 | PhysioNet | ❌ |
CinC2018 | PhysioNet | ❌ |
CinC2020 | PhysioNet | ✔️ |
CinC2021 | PhysioNet | ✔️ |
LTAFDB | PhysioNet | ❌ |
LUDB | PhysioNet | ✔️ |
MITDB | PhysioNet | ❌ |
SHHS | NSRR | ❌ |
CPSC2018 | CPSC | ✔️ |
CPSC2019 | CPSC | ✔️ |
CPSC2020 | CPSC | ✔️ |
CPSC2021 | CPSC | ✔️ |
NOTE that these classes should not be confused with a torch Dataset, which is strongly tied to the task (or the model). However, one can build Datasets based on these classes, for example the Dataset for the 4th China Physiological Signal Challenge 2021 (CPSC2021).
One can use the built-in Datasets in torch_ecg.databases.datasets as follows:
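from copy import deepcopy
from torch_ecg.databases.datasets.cinc2021 import CINC2021Dataset, CINC2021TrainCfg
# copy the default training config so that the built-in one is not modified
config = deepcopy(CINC2021TrainCfg)
config.db_dir = "some/path/to/db"  # path to the local copy of the database
dataset = CINC2021Dataset(config, training=True, lazy=False)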
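As an illustration of building a task-specific dataset on top of a database class, the sketch below wraps a reader in the map-style `__len__`/`__getitem__` protocol that `torch.utils.data.Dataset` uses. `ToyReader`, `load_data`, `load_ann` and `RecordDataset` are hypothetical names for this example; in practice one would subclass `torch.utils.data.Dataset` and call the real reader's methods.

```python
from typing import List, Sequence, Tuple


class ToyReader:
    """Hypothetical stand-in for a database reader: lists records and loads them."""

    def __init__(self, records: Sequence[str]):
        self.all_records = list(records)

    def load_data(self, rec: str) -> List[float]:
        return [float(len(rec))] * 4  # placeholder "signal"

    def load_ann(self, rec: str) -> str:
        return "NSR"  # placeholder label


class RecordDataset:
    """Map-style dataset built on top of a reader; the task-specific logic
    (here, one-hot encoding of the annotation) lives in the dataset, not the reader."""

    def __init__(self, reader: ToyReader, classes: Sequence[str]):
        self.reader = reader
        self.classes = list(classes)

    def __len__(self) -> int:
        return len(self.reader.all_records)

    def __getitem__(self, index: int) -> Tuple[List[float], List[float]]:
        rec = self.reader.all_records[index]
        sig = self.reader.load_data(rec)
        ann = self.reader.load_ann(rec)
        label = [1.0 if c == ann else 0.0 for c in self.classes]
        return sig, label
```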
Implemented Neural Network Architectures
- CRNN, both for classification and sequence tagging (segmentation)
- U-Net
- RR-LSTM
A typical signature of the instantiation (__init__) function of a model is as follows:
__init__(self, classes:Sequence[str], n_leads:int, config:Optional[CFG]=None, **kwargs:Any) -> NoReturn
If config is not specified, the default config (stored in the model_configs module) will be used.
Quick Example
A quick example is as follows:
import torch
from torch_ecg.utils.utils_nn import adjust_cnn_filter_lengths
from torch_ecg.model_configs import ECG_CRNN_CONFIG
from torch_ecg.models.ecg_crnn import ECG_CRNN
config = adjust_cnn_filter_lengths(ECG_CRNN_CONFIG, fs=400)
# change the default CNN backbone
# bottleneck with global context attention variant of Nature Communications ResNet
config.cnn.name="resnet_nature_comm_bottle_neck_gc"
classes = ["NSR", "AF", "PVC", "SPB"]
n_leads = 12
model = ECG_CRNN(classes, n_leads, config)
model(torch.rand(2, 12, 4000)) # signal length 4000, batch size 2
This creates a model that classifies 12-lead ECGs into 4 classes, namely "NSR", "AF", "PVC" and "SPB". One can check the size of a model in terms of the number of parameters via
model.module_size
or in terms of memory consumption via
model.module_size_
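The parameter count and the memory footprint are related by the number of bytes per element. The sketch below computes both from a list of parameter shapes, assuming float32 (4 bytes per element); it illustrates the arithmetic and is not how `module_size` or `module_size_` is implemented. For an actual PyTorch model, the count is commonly obtained via `sum(p.numel() for p in model.parameters())`.

```python
from functools import reduce
from operator import mul
from typing import Iterable, Tuple


def n_params(shapes: Iterable[Tuple[int, ...]]) -> int:
    """Total number of scalar parameters, given the shape of each parameter tensor."""
    return sum(reduce(mul, shape, 1) for shape in shapes)


def memory_bytes(shapes: Iterable[Tuple[int, ...]], bytes_per_param: int = 4) -> int:
    """Approximate parameter memory; 4 bytes per element corresponds to float32."""
    return n_params(shapes) * bytes_per_param
```

For example, a 1D convolution from 12 leads to 64 channels with kernel size 15 and a bias has parameter shapes `(64, 12, 15)` and `(64,)`, i.e. 11,584 parameters and 46,336 bytes (about 45 KiB) in float32.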
Custom Model
One can adjust the configs to create a custom model. For example, the building blocks of the 4 stages of a TResNet backbone are basic, basic, bottleneck, bottleneck. To change the second block to a bottleneck block with squeeze-and-excitation (SE) attention, one can do:
from copy import deepcopy
from torch_ecg.models.ecg_crnn import ECG_CRNN
from torch_ecg.model_configs import (
    ECG_CRNN_CONFIG,
    tresnetF, resnet_bottle_neck_se,
)
# work on a copy so that the built-in TResNet config is left untouched
my_resnet = deepcopy(tresnetF)
# second stage: bottleneck block with SE attention
my_resnet.building_block[1] = "bottleneck"
my_resnet.block[1] = resnet_bottle_neck_se
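For reference, the standard squeeze-and-excitation operation on a 1D feature map can be sketched in NumPy: global average pooling over time, two fully connected layers (ReLU, then sigmoid) forming a bottleneck, and channel-wise rescaling. Biases are omitted and the weights are passed in explicitly; this is an illustration of the technique, not the code behind `resnet_bottle_neck_se`.

```python
import numpy as np


def squeeze_excitation(x: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Squeeze-and-excitation on a 1D feature map (no biases).

    x:  (channels, length) feature map
    w1: (channels // reduction, channels) weights of the first FC layer
    w2: (channels, channels // reduction) weights of the second FC layer
    """
    s = x.mean(axis=-1)                  # squeeze: global average pool over time
    z = np.maximum(w1 @ s, 0.0)          # excitation: FC + ReLU
    e = 1.0 / (1.0 + np.exp(-(w2 @ z)))  # FC + sigmoid -> per-channel weights in (0, 1)
    return x * e[:, None]                # rescale each channel
```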
The convolutions in a TResNet are anti-aliasing convolutions. To further change them to normal convolutions, one can do:
for b in my_resnet.block:
b.conv_type = None
or change them to separable convolutions via
for b in my_resnet.block:
b.conv_type = "separable"
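A common motivation for separable convolutions is their smaller parameter count. The sketch below compares a standard 1D convolution with the usual depthwise-plus-pointwise factorization; whether torch_ecg's "separable" option uses exactly this factorization is an assumption here.

```python
def conv1d_params(c_in: int, c_out: int, k: int, bias: bool = False) -> int:
    """Weights of a standard 1D convolution: every output channel sees every input channel."""
    return c_in * c_out * k + (c_out if bias else 0)


def separable_conv1d_params(c_in: int, c_out: int, k: int, bias: bool = False) -> int:
    """Depthwise (one k-tap filter per input channel) followed by a pointwise 1x1 convolution."""
    depthwise = c_in * k + (c_in if bias else 0)
    pointwise = c_in * c_out + (c_out if bias else 0)
    return depthwise + pointwise
```

For 64 input channels, 64 output channels and a kernel size of 9, the standard convolution has 36,864 weights while the separable one has 4,672, roughly 7.9 times fewer.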
Finally, replace the default CNN backbone via
my_model_config = deepcopy(ECG_CRNN_CONFIG)
my_model_config.cnn.name = "my_resnet"
my_model_config.cnn.my_resnet = my_resnet
model = ECG_CRNN(["NSR", "AF", "PVC", "SPB"], 12, my_model_config)
CNN Backbones
Implemented
- VGG
- ResNet (including vanilla ResNet, ResNet-B, ResNet-C, ResNet-D, ResNeXT, TResNet, Stanford ResNet, Nature Communications ResNet, etc.)
- MultiScopicNet (CPSC2019 SOTA)
- DenseNet (CPSC2020 SOTA)
- Xception
In general, variants of ResNet are the most commonly used architectures, as can be inferred from CinC2020 and CinC2021.
Ongoing
- MobileNet
- DarkNet
- EfficientNet
TODO
- HarDNet
- HO-ResNet
- U-Net++
- U-Squared Net
- etc.
More details and a list of references can be found in the README file of this module.
Components
This module consists of frequently used components such as loggers, trainers, etc.
Loggers
The following loggers are implemented and managed uniformly by a manager:
- CSV logger
- text logger
- TensorBoard logger
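A minimal sketch of the "one manager, many loggers" pattern, using only the standard library: every logger implements the same `log_metrics` method and the manager forwards each call to all of them. All class and method names are hypothetical, and a TensorBoard logger (which would wrap `torch.utils.tensorboard.SummaryWriter`) is left out to keep the example dependency-free.

```python
import csv
import io
from typing import Dict, List


class BaseLoggerSketch:
    """Hypothetical common interface shared by all loggers."""

    def log_metrics(self, step: int, metrics: Dict[str, float]) -> None:
        raise NotImplementedError


class CSVLoggerSketch(BaseLoggerSketch):
    def __init__(self, stream: io.StringIO, fields: List[str]):
        self.writer = csv.DictWriter(stream, fieldnames=["step"] + fields)
        self.writer.writeheader()

    def log_metrics(self, step: int, metrics: Dict[str, float]) -> None:
        self.writer.writerow({"step": step, **metrics})


class TextLoggerSketch(BaseLoggerSketch):
    def __init__(self, stream: io.StringIO):
        self.stream = stream

    def log_metrics(self, step: int, metrics: Dict[str, float]) -> None:
        text = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
        self.stream.write(f"step {step}: {text}\n")


class LoggerManagerSketch:
    """Forwards every call to all registered loggers, so callers use one object."""

    def __init__(self):
        self.loggers: List[BaseLoggerSketch] = []

    def add_logger(self, logger: BaseLoggerSketch) -> None:
        self.loggers.append(logger)

    def log_metrics(self, step: int, metrics: Dict[str, float]) -> None:
        for logger in self.loggers:
            logger.log_metrics(step, metrics)
```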
Outputs
The Output classes implemented in this module serve as containers for the outputs of models for downstream ECG tasks, including
- ClassificationOutput
- MultiLabelClassificationOutput
- SequenceTaggingOutput
- WaveDelineationOutput
- RPeaksDetectionOutput
Each of them has some required fields (keys) and can hold an arbitrary number of custom fields. These classes are useful for computing metrics.
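A sketch of such a container, assuming a dict with attribute access that checks a set of required fields on construction and accepts any number of extra fields. `OutputSketch`, `ClassificationOutputSketch` and the chosen required fields are hypothetical and not necessarily those of `ClassificationOutput`.

```python
from typing import Any, Sequence


class OutputSketch(dict):
    """Hypothetical output container: a dict with attribute access that
    enforces a set of required fields and accepts any number of extra ones."""

    required_fields: Sequence[str] = ()

    def __init__(self, **kwargs: Any):
        missing = [f for f in self.required_fields if f not in kwargs]
        if missing:
            raise KeyError(f"missing required fields: {missing}")
        super().__init__(**kwargs)

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


class ClassificationOutputSketch(OutputSketch):
    # assumed required fields, chosen for illustration only
    required_fields = ("classes", "prob", "pred")
```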
Metrics
This module has the following pre-defined (built-in) Metrics classes:
- ClassificationMetrics
- RPeaksDetectionMetrics
- WaveDelineationMetrics
These metrics are computed according to either Wikipedia or published literature.
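As an example of the textbook definitions, sensitivity, precision, specificity and F1 for one class follow directly from the confusion-matrix counts. The sketch assumes 0/1 labels and returns 0.0 when a denominator is zero; that convention is an assumption and may differ from `ClassificationMetrics`.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Confusion-matrix based metrics for one class (labels are 0/1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    sens = tp / (tp + fn) if tp + fn else 0.0  # sensitivity (recall)
    prec = tp / (tp + fp) if tp + fp else 0.0  # precision (PPV)
    spec = tn / (tn + fp) if tn + fp else 0.0  # specificity
    f1 = 2 * prec * sens / (prec + sens) if prec + sens else 0.0
    return {"sens": sens, "prec": prec, "spec": spec, "f1": f1}
```

For `y_true = [1, 1, 0, 0, 1]` and `y_pred = [1, 0, 0, 1, 1]` there are 2 true positives, 1 false positive, 1 false negative and 1 true negative, giving a sensitivity, precision and F1 of 2/3 and a specificity of 1/2.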
Trainer
An abstract base class BaseTrainer is implemented, which takes care of some common steps of building a training pipeline (workflow). A few task-specific methods are declared as abstractmethods, for example the method
evaluate(self, data_loader:DataLoader) -> Dict[str, float]
for evaluation on the validation set during training, and possibly also for model selection and early stopping.
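The split between generic and task-specific steps can be sketched with Python's `abc` module: the base class owns the epoch loop, model selection and early stopping, while `run_one_step` and `evaluate` are abstract. `TrainerSketch`, `monitor` and `patience` are hypothetical names; the real `BaseTrainer` has a different and richer interface.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List


class TrainerSketch(ABC):
    """Hypothetical minimal trainer: the generic loop lives in the base class,
    task-specific steps are left abstract for subclasses."""

    def __init__(self, monitor: str, patience: int):
        self.monitor = monitor    # validation metric used for model selection
        self.patience = patience  # epochs without improvement before stopping
        self.best_score = float("-inf")
        self.best_epoch = -1

    @abstractmethod
    def run_one_step(self, batch: Any) -> float:
        """Forward/backward pass on one batch; returns the loss."""

    @abstractmethod
    def evaluate(self, data_loader: Iterable[Any]) -> Dict[str, float]:
        """Compute validation metrics."""

    def train(self, train_loader: Iterable[Any], val_loader: Iterable[Any],
              n_epochs: int) -> List[Dict[str, float]]:
        history = []
        for epoch in range(n_epochs):
            for batch in train_loader:
                self.run_one_step(batch)
            metrics = self.evaluate(val_loader)
            history.append(metrics)
            if metrics[self.monitor] > self.best_score:  # model selection
                self.best_score, self.best_epoch = metrics[self.monitor], epoch
            elif epoch - self.best_epoch >= self.patience:  # early stopping
                break
        return history
```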
Other Useful Tools
R peak detection algorithms
This is a collection of traditional (non-deep-learning) algorithms for R peak detection, collected from WFDB and BioSPPy.
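For orientation, the simplest traditional R peak detector keeps local maxima above an amplitude threshold and enforces a refractory period between beats. The sketch below implements that toy rule; it is not one of the algorithms collected from WFDB or BioSPPy, and the threshold and the 200 ms refractory period are assumptions.

```python
from typing import List, Sequence


def detect_r_peaks(sig: Sequence[float], fs: float, threshold: float,
                   refractory_s: float = 0.2) -> List[int]:
    """Toy detector: local maxima above `threshold`, at least `refractory_s` apart.

    When two candidates fall within the refractory period, the taller one is kept.
    """
    min_dist = int(refractory_s * fs)  # refractory period in samples
    peaks: List[int] = []
    for i in range(1, len(sig) - 1):
        if sig[i] >= threshold and sig[i] >= sig[i - 1] and sig[i] > sig[i + 1]:
            if peaks and i - peaks[-1] < min_dist:
                if sig[i] > sig[peaks[-1]]:
                    peaks[-1] = i  # replace the weaker candidate
            else:
                peaks.append(i)
    return peaks
```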
Usage Examples
See case studies in the benchmarks folder.
A large part of the case studies are migrated from other DeepPSP repositories. Some of them are implemented in the old fashion, inconsistent with the new system architecture of torch_ecg, and hence need updating and testing.
Benchmark | Architecture | Source | Finished | Updated | Tested |
---|---|---|---|---|---|
CinC2020 | CRNN | DeepPSP/cinc2020 | ✔️ | ✔️ | ✔️ |
CinC2021 | CRNN | DeepPSP/cinc2021 | ✔️ | ✔️ | ✔️ |
CPSC2019 | SequenceTagging/U-Net | NA | ✔️ | ✔️ | ✔️ |
CPSC2020 | CRNN/SequenceTagging | DeepPSP/cpsc2020 | ✔️ | ❌ | ❌ |
CPSC2021 | CRNN/SequenceTagging/LSTM | DeepPSP/cpsc2021 | ✔️ | ✔️ | ✔️ |
LUDB | U-Net | NA | ✔️ | ✔️ | ✔️ |
Taking CPSC2021 as an example, the steps are:
- Write a Dataset to fit the training data for the model(s) and the training workflow, or directly use the built-in Datasets in torch_ecg.databases.datasets. In this example, 3 tasks are considered, 2 of which use a MaskedBCEWithLogitsLoss function, hence the Dataset produces an extra tensor for these 2 tasks:
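def __getitem__(self, index:int) -> Tuple[np.ndarray, ...]:
    if self.lazy:
        # lazy mode: items are fetched through self.fdr
        if self.task in ["qrs_detection"]:
            # QRS detection does not use the mask: (signal, label) only
            return self.fdr[index][:2]
        else:
            # the other tasks also need the mask for MaskedBCEWithLogitsLoss
            return self.fdr[index]
    else:
        # non-lazy mode: data, labels and masks are already held in memory
        if self.task in ["qrs_detection"]:
            return self._all_data[index], self._all_labels[index]
        else:
            return self._all_data[index], self._all_labels[index], self._all_masks[index]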
- Inherit a base model to create task-specific models, along with tailored model configs
- Inherit the BaseTrainer to build the training pipeline, with the abstractmethods (_setup_dataloaders, run_one_step, evaluate, batch_dim, etc.) implemented.
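The extra tensor produced by the Dataset is a mask that weights the element-wise loss. A NumPy sketch of a masked binary cross-entropy with logits, assuming the weighted loss is normalized by the total mask weight; the actual `MaskedBCEWithLogitsLoss` may weight or reduce differently.

```python
import numpy as np


def masked_bce_with_logits(logits: np.ndarray, targets: np.ndarray,
                           mask: np.ndarray) -> float:
    """Element-wise BCE-with-logits weighted by `mask`, averaged over the mask weight.

    Uses the numerically stable form max(x, 0) - x * y + log(1 + exp(-|x|)).
    """
    bce = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return float((bce * mask).sum() / mask.sum())
```

Positions where the mask is 0 contribute nothing, so, for example, padded or ambiguous regions of a `(batch, siglen, 1)` target do not affect the loss.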
CAUTION
Most of the time, but not always, I run the notebooks in the benchmarks manually after updates. If you find a bug, please raise an issue. The test workflow is to be enhanced and automated; see this project.
Work in progress
See the projects page.
Citation
@misc{torch_ecg,
author = {WEN, Hao and KANG, Jingsu},
title = {{torch\_ecg: An ECG Deep Learning Framework Implemented using PyTorch}},
doi = {10.5281/ZENODO.6435048},
url = {https://zenodo.org/record/6435048},
publisher = {Zenodo},
year = {2022},
copyright = {MIT License}
}
@article{torch_ecg_paper,
author = {Wen, Hao and Kang, Jingsu},
title = {Investigating Deep Learning Benchmarks for Electrocardiography Signal Processing},
doi = {10.48550/ARXIV.2204.04420},
publisher = {arXiv},
year = {2022},
journal = {arXiv preprint arXiv:2204.04420},
copyright = {Creative Commons Attribution 4.0 International}
}
Thanks
Much has been learned, especially regarding the modular design, from the adversarial NLP library TextAttack and from Hugging Face transformers.