This codebase implements a knowledge distillation (KD) approach for ECG-based sleep staging, assisted by an EEG-based sleep staging model. Knowledge distillation is incorporated in two ways: softmax distillation, and feature training based on attention transfer. The proposed model combines both.
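The two distillation signals can be illustrated with their standard formulations: temperature-softened softmax distillation (Hinton et al.) and attention transfer on intermediate feature maps (Zagoruyko and Komodakis). The NumPy sketch below is not the repository's PyTorch code. It assumes features shaped (batch, channels, time) and that the ECG student and EEG teacher feature maps share the same time length. The default temperature of 4.0 is a hypothetical value.

```python
import numpy as np


def softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Temperature-scaled softmax over the last (class) axis."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def softmax_distillation_loss(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on softened distributions, scaled by T^2."""
    p_t = softmax(teacher_logits, temperature)  # EEG teacher (assumed)
    p_s = softmax(student_logits, temperature)  # ECG student (assumed)
    kl = np.sum(p_t * (np.log(p_t + 1e-12) - np.log(p_s + 1e-12)), axis=-1)
    return float(kl.mean() * temperature ** 2)


def attention_map(features: np.ndarray, p: int = 2) -> np.ndarray:
    """Collapse (batch, channels, time) features into an L2-normalised (batch, time) map."""
    a = np.mean(np.abs(features) ** p, axis=1)  # average |A|^p over channels
    return a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-12)


def attention_transfer_loss(student_feats, teacher_feats, p=2):
    """Mean squared difference between normalised student and teacher attention maps."""
    diff = attention_map(student_feats, p) - attention_map(teacher_feats, p)
    return float(np.mean(diff ** 2))
```

Because the attention map averages over channels, the student and teacher may have different channel counts. Only the time axis has to match.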
The code is implemented with the PyTorch Lightning framework and was run inside a Docker container. The dependencies used inside the container are listed in requirements.txt.
Experiments can be reproduced by following the procedure in the Reproducibility section.
The code will be updated with a generator-based dataset and dataloader to address memory constraints.
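A generator-based loader keeps only one recording in memory at a time. The standard-library sketch below illustrates the idea under stated assumptions: `load_subject` is a hypothetical per-subject loader, and each epoch is assumed to be a (samples, label) pair. In PyTorch, such a generator could be wrapped in `torch.utils.data.IterableDataset`.

```python
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple

Epoch = Tuple[Sequence[float], int]  # (samples of one 30 s epoch, stage label)


def iter_epochs(subject_ids: Iterable[str],
                load_subject: Callable[[str], List[Epoch]]) -> Iterator[Epoch]:
    """Yield epochs subject by subject, so only one recording is loaded at a time."""
    for sid in subject_ids:
        for epoch in load_subject(sid):  # hypothetical per-subject loader
            yield epoch


def iter_batches(epochs: Iterator[Epoch], batch_size: int) -> Iterator[List[Epoch]]:
    """Group a stream of epochs into lists of at most batch_size items."""
    batch: List[Epoch] = []
    for epoch in epochs:
        batch.append(epoch)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly smaller, batch
```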
Dataset: Montreal Archive of Sleep Studies (MASS); data from all 200 subjects are used.
- SS1 and SS3 subsets follow AASM guidelines
- SS2, SS4 and SS5 subsets follow Rechtschaffen and Kales (R_K) guidelines
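Because AASM and R_K score stages differently (R_K splits deep sleep into stages 3 and 4), both must be mapped onto a common label set. The sketch below assumes a common grouping in ECG-based staging: 4 classes (Wake, Light, Deep, REM) and 3 classes (Wake, NREM, REM). This grouping is an assumption and is not taken from the repository's split scripts. The short stage names are also hypothetical and may not match the MASS annotation strings.

```python
from typing import Dict, Optional

# ASSUMPTION: Wake=0, Light=1, Deep=2, REM=3; stage names are hypothetical.
AASM_TO_4: Dict[str, int] = {"W": 0, "N1": 1, "N2": 1, "N3": 2, "R": 3}
RK_TO_4: Dict[str, int] = {"W": 0, "S1": 1, "S2": 1, "S3": 2, "S4": 2, "REM": 3}
FOUR_TO_3: Dict[int, int] = {0: 0, 1: 1, 2: 1, 3: 2}  # Light and Deep merge into NREM


def map_stage(stage: str, guideline: str, n_classes: int) -> Optional[int]:
    """Map a scored stage to a 4- or 3-class label; None for unmapped epochs (e.g. movement)."""
    if guideline not in ("AASM", "R_K"):
        raise ValueError(f"unknown guideline: {guideline}")
    if n_classes not in (3, 4):
        raise ValueError("n_classes must be 3 or 4")
    table = AASM_TO_4 if guideline == "AASM" else RK_TO_4
    label = table.get(stage)
    if label is None:
        return None  # caller decides whether to drop the epoch
    return label if n_classes == 4 else FOUR_TO_3[label]
```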
The knowledge distillation framework uses U-Time, with minor modifications, as the base model.
Improvement in the bottleneck features from the ECG_Base model to the KD_model as a result of knowledge distillation, compared against the EEG_base model features.
Case 1: KD_model predicts correctly, ECG_Base predicts incorrectly
Case 2: KD_model predicts incorrectly, ECG_Base predicts correctly
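Epochs belonging to the two cases can be selected from per-epoch predictions. The sketch below assumes that the ground-truth labels and the predictions of both models are aligned integer sequences; the function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def case_indices(y_true: Sequence[int], kd_pred: Sequence[int],
                 ecg_pred: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Return (case1, case2) epoch indices.

    case1: KD_model correct and ECG_Base incorrect.
    case2: KD_model incorrect and ECG_Base correct.
    """
    case1: List[int] = []
    case2: List[int] = []
    for i, (t, k, e) in enumerate(zip(y_true, kd_pred, ecg_pred)):
        if k == t and e != t:
            case1.append(i)
        elif k != t and e == t:
            case2.append(i)
    return case1, case2
```

For example, with `y_true=[0, 1, 2, 3]`, `kd_pred=[0, 1, 0, 3]` and `ecg_pred=[1, 1, 2, 0]`, epochs 0 and 3 fall into Case 1 and epoch 2 falls into Case 2.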
Run train.py from the 3_class or 4_class directory.
To train the baseline models:
python train.py --model_type <"base model type"> --model_ckpt_name <"ckpt name">
To run knowledge distillation (AT: attention transfer, SD: softmax distillation):
- Feature Training
python train.py --model_type "feat_train" --model_ckpt_name <"ckpt name"> --eeg_baseline_path <"eeg base ckpt path">
- Feat_Temp (AT+SD+CL)
python train.py --model_type "feat_temp" --model_ckpt_name <"ckpt name"> --feat_path <"path to feature trained ckpt">
- Feat_WCE (AT+CL)
python train.py --model_type "feat_wce" --model_ckpt_name <"ckpt name"> --feat_path <"path to feature trained ckpt">
- KD-Temp (SD+CL)
python train.py --model_type "kd_temp" --model_ckpt_name <"ckpt name"> --eeg_baseline_path <"eeg base ckpt path">
Run test.py from the 3_class or 4_class directory.
To test from a checkpoint:
python test.py --model_type <"model type"> --test_ckpt <"path to checkpoint">
Other arguments are available for training and testing and can be set as required.
Scripts that split the data into train/validation/test sets for the 4-class and 3-class cases (both AASM and R_K):
├── Dataset_split
│ ├── Data_split_3class_AllData30s_R_K.py
│ ├── Data_split_3class_AllData_AASM.py
│ ├── Data_split_AllData_30s_R_K.py
│ └── Data_split_All_Data_AASM.py
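A common way to build such splits is subject-wise, so that no subject contributes epochs to more than one set. The sketch below illustrates that approach. Whether the repository's scripts split by subject, and which ratios they use, is not stated here, so the fractions and seed are hypothetical.

```python
import random
from typing import Dict, List, Sequence


def split_subjects(subject_ids: Sequence[str], val_frac: float = 0.1,
                   test_frac: float = 0.2, seed: int = 0) -> Dict[str, List[str]]:
    """Shuffle subjects with a fixed seed and split them into disjoint sets."""
    ids = sorted(subject_ids)          # deterministic starting order
    random.Random(seed).shuffle(ids)   # seeded shuffle for reproducibility
    n_test = round(len(ids) * test_frac)
    n_val = round(len(ids) * val_frac)
    return {
        "test": ids[:n_test],
        "val": ids[n_test:n_test + n_val],
        "train": ids[n_test + n_val:],
    }
```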
Run train.py with the necessary arguments to train 3-class sleep staging models.
├── 3_class
│ ├── datasets
│ │ ├── __init__.py
│ │ └── mass.py
│ │
│ ├── models
│ │ ├── __init__.py
│ │ ├── ecg_base.py
│ │ ├── eeg_base.py
│ │ ├── FEAT_TEMP.py
│ │ ├── FEAT_TRAINING.py
│ │ ├── FEAT_WCE.py
│ │ └── KD_TEMP.py
│ │
│ ├── test.py
│ ├── train.py
│ └── utils
│   ├── __init__.py
│   ├── arg_utils.py
│   ├── callback_utils.py
│   ├── dataset_utils.py
│   └── model_utils.py
Run train.py with the necessary arguments to train 4-class sleep staging models.
├── 4_class
│ ├── datasets
│ │ ├── __init__.py
│ │ └── mass.py
│ │
│ ├── models
│ │ ├── __init__.py
│ │ ├── ecg_base.py
│ │ ├── eeg_base.py
│ │ ├── FEAT_TEMP.py
│ │ ├── FEAT_TRAINING.py
│ │ ├── FEAT_WCE.py
│ │ └── KD_TEMP.py
│ │
│ ├── test.py
│ ├── train.py
│ └── utils
│   ├── __init__.py
│   ├── arg_utils.py
│   ├── callback_utils.py
│   ├── dataset_utils.py
│   └── model_utils.py