Musa Tahir, Afnan Nuruzzaman, Omer Chaudhry, Seong-Heon Jung
Sleep is a critical aspect of our health; it has long been established that interrupted sleep is closely linked to sleep disorders such as sleep apnea and insomnia. There are six broad stages of sleep: wake (W), rapid eye movement (REM), and the four non-REM stages, known as N1, N2, N3, and N4. Sleep specialists use polysomnograms (PSGs) to determine sleep stages. These records consist of an electroencephalogram (EEG), an electrooculogram (EOG), an electromyogram (EMG), and an electrocardiogram (ECG). The paper this project is based on uses deep learning methods to classify EEG signals into sleep stages. It was not the first attempt to apply deep learning to this problem. What is novel is that the authors used a Temporal Context Encoder (TCE) to capture temporal dependencies in the extracted features that previous deep learning implementations could not capture. Additionally, the authors addressed data imbalance in an innovative way. In general, we spend different amounts of time in different sleep stages; the less time we spend in a sleep stage, the less data we have to train our model on that stage. In the past, researchers oversampled to overcome this issue. The authors of this paper instead redesigned the loss function. In our implementation of this project, we ported the code base from the original PyTorch to TensorFlow.
##Methodology
Our dataset consisted of single-channel EEG data, which are typically divided into 30-second intervals that are each classified into one of six sleep stages. Figure 1 below shows the variation that exists across sleep stages in single-channel EEGs. Preprocessing consisted of three steps. First, we removed unknown sleep stages from the dataset. Second, we merged stages N3 and N4 to simplify the classification problem and reduce model sensitivity to inter-scorer variability, as N3 and N4 are relatively difficult to distinguish in certain cases. Third, to keep the focus on the sleep stages themselves, we included only 30 minutes of wake before and after each sleep period.
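The three preprocessing steps can be sketched on a sequence of per-epoch labels. This is a minimal illustration, not the project's actual preprocessing code: the function name, the `"?"` marker for unscored epochs, and the string labels are assumptions made for the example.

```python
from typing import List, Sequence

EPOCH_SECONDS = 30   # each label covers one 30-second interval
WAKE = "W"
UNKNOWN = "?"        # hypothetical marker for unscored epochs


def preprocess_labels(labels: Sequence[str], wake_margin_minutes: int = 30) -> List[str]:
    """Apply the three preprocessing steps to a per-epoch label sequence."""
    # Step 1: drop epochs whose sleep stage is unknown.
    known = [stage for stage in labels if stage != UNKNOWN]

    # Step 2: merge N4 into N3.
    merged = ["N3" if stage == "N4" else stage for stage in known]

    # Step 3: keep at most `wake_margin_minutes` of wake around the sleep period.
    sleep_indices = [i for i, stage in enumerate(merged) if stage != WAKE]
    if not sleep_indices:
        return []
    margin = wake_margin_minutes * 60 // EPOCH_SECONDS  # margin in epochs
    start = max(0, sleep_indices[0] - margin)
    end = min(len(merged), sleep_indices[-1] + 1 + margin)
    return merged[start:end]
```

In a real pipeline the same `start` and `end` indices would also be used to slice the EEG signal so that signals and labels stay aligned.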
The main architecture that we implement in this project, AttnSleep, uses attention mechanisms to classify sleep stages from single-channel EEG signals. The architecture consists of three main modules: a feature extraction module that combines a multi-resolution convolutional neural network (MRCNN) with adaptive feature recalibration (AFR), a temporal context encoder (TCE), and a classification layer trained with a class-aware loss function. Within the feature extraction module, the MRCNN comprises two branches of convolutional layers that use distinct kernel sizes to capture features across different frequency bands in the EEG signals. This lets the model explore various sleep-related frequency bands and extract both low- and high-frequency features. The original model’s convolution layers use custom numerical padding values to give the output of each convolution layer a specific shape. Unfortunately, we found that this approach makes the model inflexible when the input dimensionality changes, since the padding values are hard-coded with a specific input shape in mind. For example, the original AttnSleep model has two distinct versions of the MRCNN layer, one for the Physionet dataset and another for the SHHS dataset, because the two datasets have different shapes. To make the model more extensible, ours uses only “valid” or “same” padding, and therefore requires only one MRCNN for both datasets.
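To see why “valid” and “same” padding adapt to any input length, it helps to write down the output length of a 1D convolution under each mode. The sketch below uses the standard formulas for an undilated convolution; the input lengths, kernel size, and stride in the usage example are illustrative assumptions, not values taken from our model.

```python
import math


def conv1d_output_length(input_length: int, kernel_size: int,
                         stride: int = 1, padding: str = "valid") -> int:
    """Output length of an undilated 1D convolution for 'valid' or 'same' padding."""
    if padding == "same":
        # The input is padded so that the output length depends only on the stride.
        return math.ceil(input_length / stride)
    if padding == "valid":
        # No padding: only window positions that fit entirely in the input count.
        return (input_length - kernel_size) // stride + 1
    raise ValueError(f"unsupported padding mode: {padding!r}")


# Illustrative values only: two recordings with different numbers of samples.
for samples in (3000, 3750):
    print(samples,
          conv1d_output_length(samples, kernel_size=50, stride=6, padding="valid"),
          conv1d_output_length(samples, kernel_size=50, stride=6, padding="same"))
```

Because neither formula involves a hand-picked padding constant, the same layer definition works for inputs of either length.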
From the outputs of the MRCNN, the AFR layer models the inter-dependencies between the learned features and selects the most informative ones using a residual squeeze-and-excitation (residual SE) block. The block has two stages. The first is a standard 1D convolution step whose outputs pass through a Batch Normalization layer and a ReLU activation function. In the second stage, the model uses a squeeze-and-excitation activation instead of ReLU. This consists of applying a 1D adaptive average pool to “squeeze” the input, which then passes through two fully connected layers with ReLU and sigmoid activations, respectively. Here, we had to implement a custom AdaptiveAveragePooling1D Keras layer, since TensorFlow has no built-in equivalent to PyTorch’s AdaptiveAvgPool1d layer.
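The idea behind adaptive average pooling is that the caller fixes the output length and the layer derives the pooling windows from the input length. A minimal pure-Python sketch of that rule on a single sequence is shown below; it follows the windowing scheme commonly used for adaptive pooling (window i spans floor(i·L/n) to ceil((i+1)·L/n)) and is not the code of our custom Keras layer.

```python
from typing import List, Sequence


def adaptive_avg_pool1d(values: Sequence[float], output_size: int) -> List[float]:
    """Average-pool `values` down to exactly `output_size` elements."""
    length = len(values)
    if length == 0 or output_size <= 0:
        raise ValueError("need a non-empty input and a positive output size")
    pooled = []
    for i in range(output_size):
        start = (i * length) // output_size                            # floor
        end = ((i + 1) * length + output_size - 1) // output_size      # ceil
        window = values[start:end]
        pooled.append(sum(window) / len(window))
    return pooled


print(adaptive_avg_pool1d([1, 2, 3, 4, 5, 6], 3))  # three equal windows
print(adaptive_avg_pool1d([1, 2, 3, 4, 5], 1))     # output size 1 is a global average
```

With an output size of 1, as in a “squeeze” step, the operation reduces to a global average over the sequence.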
Next, the temporal context encoder (TCE) uses multi-head attention with causal convolutions to capture temporal dependencies within the extracted features. The TCE layer includes a multi-head attention (MHA) layer, a normalization layer, and two fully connected layers. Two identical TCE structures are stacked to generate the final features. Here, we modernized and abstracted the MHA mechanism by replacing the original model’s custom implementation with TensorFlow’s built-in tf.keras.layers.MultiHeadAttention. This allows for easier tuning of the model, as well as potential performance improvements from the library’s optimizations. Lastly, the classification decision is made by a fully connected layer with a softmax activation function. To handle data imbalance, we use a class-aware, cost-sensitive loss function, which addresses the problem without adding extra computational complexity. Specifically, we build on the standard multi-class cross-entropy function:

$$\mathcal{L} = -\frac{1}{M} \sum_{i=1}^{M} \sum_{k=1}^{K} y_i^k \log \hat{y}_i^k$$

where $M$ is the number of samples, $K$ is the number of classes, $y_i^k$ is 1 if sample $i$ belongs to class $k$ and 0 otherwise, and $\hat{y}_i^k$ is the predicted probability that sample $i$ belongs to class $k$. The class-aware version applies class weights to this loss, so that errors on under-represented sleep stages cost more.
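A class-weighted cross-entropy of this kind can be sketched in a few lines. This is a generic illustration rather than our training code: the function name is hypothetical, labels are assumed to be integer class indices, and the weight values in the usage example are made up, since how the weights are derived is not covered here.

```python
import math
from typing import Sequence


def weighted_cross_entropy(y_true: Sequence[int],
                           y_pred: Sequence[Sequence[float]],
                           class_weights: Sequence[float],
                           eps: float = 1e-12) -> float:
    """Mean cross-entropy where each sample is scaled by the weight of its true class."""
    total = 0.0
    for label, probs in zip(y_true, y_pred):
        # Only the true class contributes, because the target is one-hot.
        total += -class_weights[label] * math.log(max(probs[label], eps))
    return total / len(y_true)


# Hypothetical weights: class 1 is rarer, so its errors are penalized twice as much.
loss = weighted_cross_entropy([0, 1], [[0.5, 0.5], [0.5, 0.5]], class_weights=[1.0, 2.0])
print(loss)
```

Setting every weight to 1 recovers the ordinary multi-class cross-entropy.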
To minimize the class-aware loss function and learn the model parameters, we use the Adam optimizer. We reimplemented the model in TensorFlow from the paper’s PyTorch implementation and trained it on our local laptops with a batch size of 128. Apart from the changes described below, we did not deviate from the original model’s hyperparameters. Specifically, we used the Adam optimizer with a weight decay of 1e-3, β1 = 0.9, β2 = 0.999, ε = 1e-08, and AMSGrad enabled. Our learning rate starts at 1e-3 and is reduced to 1e-4 after 10 epochs, which we achieve through tf.keras.callbacks.LearningRateScheduler. For initialization, all convolution layers use a Gaussian distribution with a mean of 0 and a standard deviation of 0.02. Batch Normalization layers use a Gaussian distribution with a mean of 0 and a standard deviation of 0.02.
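The learning rate step can be expressed as a small schedule function. The sketch below assumes zero-indexed epochs, so that the first ten epochs (0 through 9) use the higher rate; the function and constant names are illustrative.

```python
INITIAL_LR = 1e-3    # learning rate for the first ten epochs
REDUCED_LR = 1e-4    # learning rate afterwards
SWITCH_EPOCH = 10    # zero-indexed epoch at which the rate drops (assumption)


def step_schedule(epoch: int, lr: float = INITIAL_LR) -> float:
    """Return the learning rate to use for a zero-indexed epoch.

    The current rate `lr` is accepted but ignored, because the schedule
    depends only on the epoch number.
    """
    return INITIAL_LR if epoch < SWITCH_EPOCH else REDUCED_LR
```

A function with this `(epoch, lr)` signature is the shape that tf.keras.callbacks.LearningRateScheduler expects as its schedule argument.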
Due to the aforementioned changes in the convolution layers’ padding strategy, some layers’ output sizes differ slightly from the original model’s, going from 80 to 78 features in the channels dimension. Because the number of channels must be evenly divisible by the number of attention heads, we used 6 attention heads instead of the original model’s 5. The original paper uses 100 epochs with 20-fold cross validation. Unfortunately, following this exact configuration would have taken far too long, so we simplified the process where the original authors were being extra rigorous. In line with the original paper’s findings, we found that the model converges well before 100 epochs. We therefore trained each fold’s model for at most 40 epochs and set up a tf.keras.callbacks.EarlyStopping callback to stop training if the validation loss had been stagnant or worsening for three consecutive epochs after the tenth epoch. This ensures that at least 10 epochs are run, but also that we do not have to train for the full 40 epochs if the model starts to overfit.
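The stopping rule can be simulated on a list of validation losses to see how many epochs a run would last. This is a simplified sketch of the rule as described, not the Keras callback itself; we assume it roughly corresponds to an EarlyStopping callback monitoring the validation loss with a patience of 3 and a ten-epoch warm-up, and the function name and loss values are hypothetical.

```python
from typing import Sequence


def epochs_trained(val_losses: Sequence[float], patience: int = 3,
                   warmup_epochs: int = 10, max_epochs: int = 40) -> int:
    """Return how many epochs run before early stopping ends training."""
    best = float("inf")
    stagnant = 0
    losses = list(val_losses)[:max_epochs]
    for epoch, loss in enumerate(losses):      # epoch is zero-indexed
        if epoch < warmup_epochs:
            continue                           # warm-up: never stop, never track
        if loss < best:
            best = loss                        # improvement resets the counter
            stagnant = 0
        else:
            stagnant += 1                      # stagnant or worsening epoch
            if stagnant >= patience:
                return epoch + 1               # stop after this epoch
    return len(losses)


# Hypothetical run: the loss improves until epoch 11 and then stalls.
print(epochs_trained([1.0] * 10 + [0.5, 0.45] + [0.5] * 10))
```

In this simulation a run can never stop before the warm-up has passed, which is the property that guarantees the minimum number of epochs.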
The original paper used four metrics to compare AttnSleep’s performance against previous state-of-the-art models: categorical accuracy, F1-score, Cohen’s kappa, and geometric mean. Because we were reimplementing the model, we anticipated similar performance between the original model and ours. Thus, we used two of the four metrics, weighted categorical accuracy and F1-score, to confirm that our model performs on par with the original. Categorical accuracy is simply the number of samples the model classified correctly divided by the total number of samples. Unlike for the categorical loss, we do not apply class weights when calculating accuracy. F1-score is defined as follows:

$$F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}$$

Recall from lecture that precision and recall are defined as

$$\text{precision} = \frac{TP}{TP + FP}, \qquad \text{recall} = \frac{TP}{TP + FN}$$

where $TP$, $FP$, and $FN$ are the numbers of true positives, false positives, and false negatives for a given class.
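As a worked example, the per-class F1 score and its average across classes can be computed directly from true and predicted labels. This is a minimal sketch with integer class labels and a hypothetical function name; it treats a class with no predictions or no true samples as scoring 0 rather than raising an error, which is a choice made for the example.

```python
from typing import Sequence


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
    """Average the per-class F1 scores over all classes."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return sum(scores) / num_classes


# Class 0: precision 1, recall 1/2. Class 1: precision 2/3, recall 1.
print(macro_f1([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2))
```

Averaging per-class scores this way gives every sleep stage equal influence on the result, regardless of how many samples it has.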
In addition to metrics for prediction quality, we also roughly measured the time taken by one epoch to get a ballpark estimate of the computational cost of training. The full logs of the model can be found here
##Results
Regarding accuracy, the best “fold” version of our model performs on par with the cross-validated model presented in the paper, at just over 80% categorical accuracy on the validation data. The mean validation F1 score (the F1 score averaged across all five classes) falls slightly short of the original model, at just over 70% versus the original’s 75.1%. Note that none of our models trained to the full 40-epoch limit; most stopped after about 20 epochs due to stagnation. We anticipate that if we trained for the full 100 epochs on all 20 folds, we might achieve a virtually identical value. Regardless, the model can be seen more or less converging around epoch 20.
Unfortunately, the model suffers from a fair amount of overfitting, as can be observed in the disparity between the train and validation loss and metrics. This is to be expected, especially since some entries in the already small EDF78 dataset had been removed by the time we implemented this paper. With more samples to train on, we anticipate that overfitting would become less of a problem. We also found it interesting that lowering the learning rate at epoch 10 had such a drastic effect on the metrics. We can see a rapid increase in F1 score and categorical accuracy, which raises the question of how sensitive the model is to different learning rates.
With these figures, we have demonstrated that our rewritten AttnSleep model achieves virtually the same accuracy as the original while simultaneously streamlining the codebase. We cut the Python codebase by over 30%, going from nearly 1300 lines in the original to under 900 lines, excluding visualization-related code. In addition, the codebase now uses the newest version of TensorFlow (v2.12.0) and TF-nightly rather than PyTorch 1.4, which helps future-proof the model.
##Challenges
The paper we ultimately used in this project was not the original paper we planned on reimplementing. This change in direction is elaborated on in the reflection, but in essence, the original classification algorithm that we reimplemented using an LSTM network took too long to train. To solve this issue, we switched to a different paper that uses an attention-based model with completely different data processing techniques and architecture. We only realized the problem with the LSTM later in the semester, which left us with less time than we had anticipated. Regardless, we were able to find a better model on an adjacent topic, which let us build on what we knew from the old model and accelerated our development process.
In general, we ran into many integration issues between different files when porting the code from PyTorch to TensorFlow, which slowed down development. Moreover, different versions of PyTorch and TensorFlow have slightly different syntax and functionality, which led to more compatibility issues when integrating different libraries and packages. To resolve these errors, we had to read the TensorFlow and PyTorch documentation thoroughly to understand the differences between their implementations.
Data availability was also troublesome. We managed to locate all the data needed for the project and preprocessed two of the three datasets used by the original paper. In the preprocessing step, we noticed that some of the data was missing. For example, the EDF78 dataset does not have 78 samples at the time of writing this report. It appears that some pieces of data were removed retroactively, but the reasons remain unclear.
Furthermore, while we found the SHHS dataset, we could not download it because the file server kept throttling download speeds to less than 100 kB/s, which made downloading the 5+ GB of data impractical. We suspect that the file server is poorly maintained or underfunded.
##Reflection
We were able to meet our original goal of successfully reimplementing the paper in TensorFlow. We even surpassed this base goal by streamlining our code, making it more efficient and concise. While we could not download the SHHS dataset due to server constraints, we were still able to validate our model’s performance on the other two datasets. Overall, we are all fairly satisfied with how our project turned out. Our model behaved the way we expected it to, essentially reproducing the results described in the paper. The differences we observed between our model and the paper’s model are relatively minor; they stem primarily from the aforementioned design details that are not supported in TensorFlow and from the fact that our architecture is more adaptable.
The paper we implemented in this project was not the paper we had proposed to recreate in our first project proposal. The paper we originally wanted to reimplement was titled “Neural network analysis of sleep stages enables efficient diagnosis of narcolepsy” (published in Nature Communications). That paper used deep learning methods both to predict sleep stages and to predict narcolepsy specifically. Unfortunately, training our reimplementation took far too long, which is why we decided to implement a different paper. The downside is that we did not get to create a deep learning model that could predict narcolepsy, which was one of the reasons we were excited to take on this project in the first place. If we could do the project over again, we would do a deeper review of the literature to find a model that both met our training time requirements and could predict sleep disorders.
There are many potential extensions of our project that would be interesting to investigate had we had more time to work with the model. First, we would like to evaluate the generalizability and robustness of the AttnSleep model on datasets other than the ones the authors provided, in order to see whether our models were overfitted. It would also be interesting to incorporate additional modalities, such as heart rate variability or body movement data, on top of single-channel EEGs and see whether that improves classification performance. One exciting possible application would be to test the AttnSleep model on datasets that include patients with sleep disorders such as narcolepsy or sleep apnea and see whether any irregularities in transitions between sleep stages could be detected. Finally, investigating the interpretability of the model and its features could be a fascinating way to gain insight into some of the underlying physiological mechanisms of sleep stages.