diff --git a/logic/ml_action.py b/logic/ml_action.py
index 8f127aa..e936602 100644
--- a/logic/ml_action.py
+++ b/logic/ml_action.py
@@ -6,7 +6,7 @@
 from brainflow.board_shim import BoardShim
 
 # imported so decorator can recognize loaded model
-from model.intent.model import SpatialAttention
+import model.intent.model as model
 
 class MLAction(BaseLogic):
     def __init__(self, board, ema_decay=1/60):
diff --git a/model/intent/autoencoder_reconstruct.png b/model/intent/autoencoder_reconstruct.png
new file mode 100644
index 0000000..33d1ae4
Binary files /dev/null and b/model/intent/autoencoder_reconstruct.png differ
diff --git a/model/intent/edf_parser.py b/model/intent/edf_parser.py
index b9a2e27..7606ab1 100644
--- a/model/intent/edf_parser.py
+++ b/model/intent/edf_parser.py
@@ -23,15 +23,31 @@ def find_edf_files(directory):
 raw_list = list(p.map(mne.io.read_raw_edf, paths))
 
 def get_windows(raw):
+    raw.load_data()
+
+    # preprocessing
+    raw.notch_filter(freqs=50, method='iir')
+    raw.notch_filter(freqs=60, method='iir')
+    raw.filter(l_freq=8, h_freq=None)
+
     events, event_id = mne.events_from_annotations(raw)
     if len(event_id) != 3:
         return None
+
+    sfreq = raw.info['sfreq']
 
     # Identify T0, T1, T2
     selected_events = events[(events[:, 2] == event_id['T1']) |
                              (events[:, 2] == event_id['T2']) |
                              (events[:, 2] == event_id['T0'])]
 
+    # Create Synthetic Events to get the whole minute
+    start_event_sample = selected_events[0, 0]
+    synthetic_events = np.array([
+        [int(start_event_sample + i * sfreq), 0, 1]  # Each event 1 second apart
+        for i in range(60)
+    ])
+
     # Create epochs around these events
-    epochs = mne.Epochs(raw, selected_events, tmin=0, tmax=1.0, preload=True, baseline=None)
+    epochs = mne.Epochs(raw, synthetic_events, tmin=0, tmax=1.0, preload=True, baseline=None)
 
     # Convert epochs to NumPy arrays
     return epochs.get_data()
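A note on the `get_windows` change: rather than epoching only on the annotated T0/T1/T2 events, it now tiles sixty one-second epochs starting at the first annotated sample, so each run contributes a full minute of windows. A minimal NumPy-only sketch of that tiling, assuming the PhysioNet layout (64 channels at 160 Hz); the arrays below are illustrative stand-ins for `raw.get_data()` and `selected_events`, not code from the repo:

```python
import numpy as np

sfreq = 160                                 # PhysioNet sampling rate
raw_data = np.random.randn(64, 62 * sfreq)  # stand-in for raw.get_data()
start_event_sample = 0                      # stand-in for selected_events[0, 0]

# Sixty 1-second windows, one per synthetic event; mne.Epochs with
# tmin=0, tmax=1.0 is endpoint-inclusive, hence the +1 sample.
windows = np.stack([
    raw_data[:, start_event_sample + i * sfreq : start_event_sample + (i + 1) * sfreq + 1]
    for i in range(60)
])

print(windows.shape)  # (60, 64, 161), matching epochs.get_data()
```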
diff --git a/model/intent/edf_train.py b/model/intent/edf_train.py
index ffa6466..f43e513 100644
--- a/model/intent/edf_train.py
+++ b/model/intent/edf_train.py
@@ -1,13 +1,10 @@
 import numpy as np
 
 from keras.optimizers import Adam
-from keras.models import Sequential
 from keras.callbacks import EarlyStopping
 
 from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import MinMaxScaler as Scaler
+from sklearn.preprocessing import StandardScaler as Scaler
 
-from model import encoder, decoder
-
-import pickle
+from model import auto_encoder
 
 # Load the data
 data = np.load('dataset.pkl')
@@ -27,29 +24,27 @@
 X_train, X_val = train_test_split(data, test_size=0.2)
 
 # Build the autoencoder
-autoencoder = Sequential([
-    encoder,
-    decoder
-])
-autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='huber')
+autoencoder = auto_encoder
+autoencoder.compile(optimizer=Adam(learning_rate=0.01), loss='mse')
 
# Define the EarlyStopping callback
-early_stopping = EarlyStopping(monitor='val_loss', patience=4, restore_best_weights=True, verbose=0)
+early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True, verbose=0)
 
 # Train the autoencoder with early stopping
-batch_size = 512
+batch_size = 256 * 2
 epochs = 128
 fit_history = autoencoder.fit(
     X_train, X_train,
     epochs=epochs,
     batch_size=batch_size,
     validation_data=(X_val, X_val),
-    callbacks=[early_stopping], verbose=1
+    callbacks=[early_stopping],
+    verbose=1
 )
 
 #Save the model
 print("Saving Model")
-encoder = autoencoder.layers[0]
-decoder = autoencoder.layers[1]
+encoder = autoencoder.encoder
+decoder = autoencoder.decoder
 encoder.save('physionet_encoder.keras')
 decoder.save('physionet_decoder.keras')
@@ -67,18 +62,28 @@
 reconstructed = autoencoder.predict(X_val)
 
+X_val = X_val.transpose(0, 2, 1)
+reconstructed = reconstructed.transpose(0, 2, 1)
+
 i = random.randint(0, len(X_val) - 1)
 js = list(range(0, 64))
 random.shuffle(js)
-js = js[:4]
-original = X_val[i][js].flatten()
-reconstructed_sample = reconstructed[i][js].flatten()
-
-plt.figure(figsize=(12, 4))
-plt.subplot(1, 2, 1)
-plt.plot(original)
-plt.title('Original Data')
-plt.subplot(1, 2, 2)
-plt.plot(reconstructed_sample)
-plt.title('Reconstructed Data')
+js = js[:4]  # Select 4 random channels
+original = X_val[i][js]
+reconstructed_sample = reconstructed[i][js]
+
+# Use the dark background style
+plt.style.use('dark_background')
+
+# Create subplots for each selected channel
+fig, axs = plt.subplots(len(js), 1, figsize=(9, 16))
+
+# Plot the original and reconstructed signals for each channel
+for idx, j in enumerate(js):
+    axs[idx].plot(original[idx], label='original')
+    axs[idx].plot(reconstructed_sample[idx], label='reconstructed')
+    axs[idx].set_title(f'Channel {j} Reconstruction Comparison')
+    axs[idx].legend(loc='upper left')
+
+plt.tight_layout()
 plt.savefig('autoencoder_reconstruct.png')
\ No newline at end of file
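Since `edf_train.py` now scores reconstructions with plain MSE (plus the perceptual term added by the `CustomAutoencoder` wrapper it imports, defined in the `model.py` diff below), a useful extra check on a run is reconstruction error per channel rather than one scalar. A small sketch, assuming the `(samples, timesteps, channels)` layout this script plots; `per_channel_mse` is a hypothetical helper, and the random arrays stand in for `X_val` and `autoencoder.predict(X_val)`:

```python
import numpy as np

def per_channel_mse(x, x_hat):
    # mean over samples and timesteps, leaving one error value per channel
    return np.mean((x - x_hat) ** 2, axis=(0, 1))

X_val = np.random.randn(8, 160, 64)                          # stand-in validation split
reconstructed = X_val + 0.1 * np.random.randn(*X_val.shape)  # stand-in model output

errors = per_channel_mse(X_val, reconstructed)
print("worst channels:", np.argsort(errors)[-4:])  # candidates to eyeball in the plot
```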
diff --git a/model/intent/model.py b/model/intent/model.py
index 62abaff..e68ae3b 100644
--- a/model/intent/model.py
+++ b/model/intent/model.py
@@ -1,12 +1,13 @@
 import tensorflow as tf
 import keras
-from keras.models import Sequential
-from keras.layers import Dense, Activation, Flatten, Multiply, BatchNormalization, Dropout, Layer
-from keras.layers import SeparableConv1D, Conv1D, UpSampling1D, MaxPooling1D
+from keras.models import Sequential, Model, clone_model
+from keras.layers import Dense, Layer, DepthwiseConv1D, Conv1D
+from keras.layers import Activation, Multiply, BatchNormalization, SpatialDropout1D, UpSampling1D, GlobalAveragePooling1D, Input
+from keras.losses import MeanSquaredError as MSE, CategoricalCrossentropy
 
 ## Spatial Attention (Thanks Summer!)
-@keras.saving.register_keras_serializable()
+@keras.utils.register_keras_serializable()
 class SpatialAttention(Layer):
     def __init__(self, classes, kernel_size=7, **kwargs):
         super(SpatialAttention, self).__init__(**kwargs)
@@ -26,46 +27,164 @@ def call(self, inputs):
         x = self.conv2(x)
         return Multiply()([inputs, x])
 
+# Noise Layer
+@keras.utils.register_keras_serializable()
+class AddNoiseLayer(Layer):
+    def __init__(self, noise_factor=0.1, **kwargs):
+        super(AddNoiseLayer, self).__init__(**kwargs)
+        self.noise_factor = noise_factor
+
+    def call(self, inputs, training=None):
+        if training:
+            noise = self.noise_factor * tf.random.normal(shape=tf.shape(inputs), mean=0.0, stddev=1.0)
+            return inputs + noise
+        return inputs
+
 ## Encoder and Decoder Trained on the physionet motor imagery dataset
 ## https://www.physionet.org/content/eegmmidb/1.0.0/
 ## Thanks again to Summer, Programmerboi, Hosomi
-act = 'silu'
+kernel = 3
+e_rates = [1, 2, 4]
+d_rates = list(reversed(e_rates))
+act = 'elu'
+
+## Modification of separable convolutions to follow along this paper
+## https://journalofcloudcomputing.springeropen.com/articles/10.1186/s13677-020-00203-9
+@keras.utils.register_keras_serializable()
+class StackedDepthSeperableConv1D(Layer):
+    def __init__(self, filters, kernel_size, dilation_rates, stride=1, use_residual=False, **kwargs):
+        super(StackedDepthSeperableConv1D, self).__init__(**kwargs)
+        self.filters = filters
+        self.dilation_rates = dilation_rates
+        self.depthwise_stack = Sequential([DepthwiseConv1D(kernel_size, padding='same', dilation_rate=dr) for dr in dilation_rates])
+        self.pointwise_conv = Conv1D(filters, 1, padding='same', strides=stride)
+        self.residual_conv = None
+        if use_residual:
+            self.residual_conv = Conv1D(filters, 1, padding='same', strides=stride)
+
+    def call(self, inputs):
+        depthwise_output = self.depthwise_stack(inputs)
+        output = self.pointwise_conv(depthwise_output)
+        if self.residual_conv:
+            output += self.residual_conv(inputs)
+        return output
+
+    def build(self, input_shape):
+        super(StackedDepthSeperableConv1D, self).build(input_shape)
+
+encoder = Sequential([
+    StackedDepthSeperableConv1D(64, kernel, e_rates, 2, True),
+    BatchNormalization(), Activation(act),  # (80, 64)
+
+    StackedDepthSeperableConv1D(32, kernel, e_rates, 2, True),
+    BatchNormalization(), Activation(act),  # (40, 32)
+
+    StackedDepthSeperableConv1D(32, kernel, e_rates, 2, True),
+    BatchNormalization(), Activation(act),  # (20, 32)
 
-encoder = Sequential([
-    SeparableConv1D(128, 3, padding='same'),
-    BatchNormalization(), Activation(act), MaxPooling1D(2),
-    SeparableConv1D(64, 3, padding='same'),
-    BatchNormalization(), Activation(act), MaxPooling1D(2),
-    SeparableConv1D(32, 3, padding='same'),
-    Activation(act)
+    StackedDepthSeperableConv1D(32, kernel, e_rates, 1, False),
+    Activation('linear')
 ])
 
 decoder = Sequential([
-    SeparableConv1D(32, 3, padding='same'),
+    StackedDepthSeperableConv1D(32, kernel, d_rates, 1, True),
     BatchNormalization(), Activation(act), UpSampling1D(2),
-    SeparableConv1D(64, 3, padding='same'),
+
+    StackedDepthSeperableConv1D(32, kernel, d_rates, 1, True),
     BatchNormalization(), Activation(act), UpSampling1D(2),
-    SeparableConv1D(128, 3, padding='same'),
-    BatchNormalization(), Activation(act),
-    SeparableConv1D(64, 1, padding='same', activation='sigmoid'),
-])
+
+    StackedDepthSeperableConv1D(32, kernel, d_rates, 1, True),
+    BatchNormalization(), Activation(act), UpSampling1D(2),
+
+    StackedDepthSeperableConv1D(64, kernel, d_rates, 1, False),
+    Activation('linear')
+])
+
+## AutoEncoder Wrapper for edf_train
+## Tunes for both feature and reconstruction losses
+class CustomAutoencoder(Model):
+    def __init__(self, encoder, decoder, perceptual_weight=1.0, sd_rate=0.2):
+        super(CustomAutoencoder, self).__init__()
+        self.spatial_dropout = SpatialDropout1D(sd_rate)
+        self.encoder = encoder
+        self.decoder = decoder
+        self.perceptual_weight = perceptual_weight
+        self.mse_loss = MSE()
+
+    def call(self, inputs):
+        # Encoding and reconstructing the input
+        inputs = self.spatial_dropout(inputs)
+        original_features = self.encoder(inputs)
+        reconstruction = self.decoder(original_features)
+
+        # get features from reconstruction
+        reconstructed_features = self.encoder(reconstruction)
+
+        # Compute and add perceptual loss during the call
+        perceptual_loss = self.mse_loss(original_features, reconstructed_features)
+        self.add_loss(self.perceptual_weight * perceptual_loss)
+
+        # Return only the reconstruction for the main loss computation
+        return reconstruction
+
+auto_encoder = CustomAutoencoder(encoder, decoder)
+
+## Classifier Model that is guided by pretrained Autoencoder Teacher
+class StudentTeacherClassifier(Model):
+    def __init__(self, frozen_encoder, frozen_decoder, classes, perceptual_weight=1.0, classify_weight=1.0, **kwargs):
+        super(StudentTeacherClassifier, self).__init__(**kwargs)
+
+        # create teacher from frozen models
+        self.teacher = Sequential([frozen_decoder, frozen_encoder])
+
+        # create student from the encoder's first (trainable) layer
+        # plus frozen clones of the deeper encoder layers
+        first_layer = encoder.layers[:1]
+        cloned_encoder = clone_model(frozen_encoder)
+        cloned_layers = cloned_encoder.layers[2:]
+        for layer in cloned_layers:
+            layer.trainable = False
+
+        self.student = Sequential(first_layer + cloned_layers)
+
+        # classifier
+        self.classifier = Sequential([
+            GlobalAveragePooling1D(),
+            Dense(64, activation='relu'),
+            Dense(classes, activation='softmax', kernel_regularizer='l2')
+        ])
 
-## First Layer to convert any channels to 64 ranged [0, 1]
-def create_first_layer(channels, expanded_channels=64):
-    return Sequential([
-        SeparableConv1D(expanded_channels, channels, padding='same', use_bias=False),
-        BatchNormalization(),
-        Activation('sigmoid'),
-        Dropout(0.1),
-    ])
-
-## Last Layer to map latent space to custom classes
-def create_last_layer(classes):
-    return Sequential([
-        SpatialAttention(classes, 5),
-        Flatten(),
-        Dropout(0.1),
-        Dense(classes, activation='softmax', kernel_regularizer='l2')
-    ])
\ No newline at end of file
+        # perceptual and classification losses
+        self.perceptual_weight = perceptual_weight
+        self.classify_weight = classify_weight
+        self.percept_loss = MSE()
+        self.cce_loss = CategoricalCrossentropy()
+
+    def call(self, inputs):
+        # predict class
+        features = self.student(inputs)
+        output = self.classifier(features)
+
+        # teach the student
+        reconstruct_features = self.teacher(features)
+        perceptual_loss = self.perceptual_weight * self.percept_loss(features, reconstruct_features)
+        self.add_loss(perceptual_loss)
+
+        return output
+
+    def get_loss_function(self):
+        return lambda y_true, y_pred: self.classify_weight * self.cce_loss(y_true, y_pred)
+
+    def build(self, input_shape):
+        super(StudentTeacherClassifier, self).build(input_shape)
+
+    def get_lean_model(self):
+        model = Sequential([
+            Input(self.student.input_shape[1:]),
+            self.student,
+            self.classifier
+        ])
+        model.compile(optimizer='adam', loss='categorical_crossentropy')
+        return model
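For orientation on the new architecture: the encoder halves the time axis three times with stride-2 `StackedDepthSeperableConv1D` blocks (the `(80, 64)`, `(40, 32)`, `(20, 32)` shape comments above), and the decoder undoes that with three `UpSampling1D(2)` stages. A shape-check sketch, assuming the `model.py` above is importable:

```python
import numpy as np
import tensorflow as tf

from model import encoder, decoder  # the Sequentials defined above

# One fake 1-second window: (batch, timesteps, channels)
x = tf.convert_to_tensor(np.random.randn(1, 160, 64).astype('float32'))

z = encoder(x)      # three stride-2 blocks: 160 -> 80 -> 40 -> 20 timesteps
x_hat = decoder(z)  # three UpSampling1D(2) stages: 20 -> 160 timesteps

print(z.shape)      # expected (1, 20, 32) latent
print(x_hat.shape)  # expected (1, 160, 64) reconstruction
```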
diff --git a/model/intent/physionet_decoder.keras b/model/intent/physionet_decoder.keras
index 8538f3b..c4c68d9 100644
Binary files a/model/intent/physionet_decoder.keras and b/model/intent/physionet_decoder.keras differ
diff --git a/model/intent/physionet_encoder.keras b/model/intent/physionet_encoder.keras
index c630157..28ecd4f 100644
Binary files a/model/intent/physionet_encoder.keras and b/model/intent/physionet_encoder.keras differ
diff --git a/model/intent/pipeline.py b/model/intent/pipeline.py
index 0ed2bc8..c6016c5 100644
--- a/model/intent/pipeline.py
+++ b/model/intent/pipeline.py
@@ -3,25 +3,31 @@
 import numpy as np
 from scipy import signal
 
-from brainflow.data_filter import DataFilter, DetrendOperations, NoiseTypes, FilterTypes
+from brainflow.data_filter import DataFilter, DetrendOperations, NoiseTypes, FilterTypes, WaveletTypes, ThresholdTypes
+
+from sklearn.preprocessing import StandardScaler as Scaler
 
 abs_script_path = os.path.abspath(__file__)
 abs_script_dir = os.path.dirname(abs_script_path)
 
+scaler = Scaler()
+
 ## preprocess and extract features to be shared between train and test
 def preprocess_data(session_data, sampling_rate):
     for eeg_chan in range(len(session_data)):
-        DataFilter.detrend(session_data[eeg_chan], DetrendOperations.LINEAR)
+        # remove line noise
         DataFilter.remove_environmental_noise(session_data[eeg_chan], sampling_rate, NoiseTypes.FIFTY_AND_SIXTY.value)
-        DataFilter.perform_lowpass(session_data[eeg_chan], sampling_rate, 80, 4, FilterTypes.BUTTERWORTH.value, 0) # resample effect mitigation
+        # bandpass to alpha, beta, and gamma; the 80 Hz cap mitigates resampling artifacts
+        DataFilter.perform_bandpass(session_data[eeg_chan], sampling_rate, 8, 80, 4, FilterTypes.BUTTERWORTH.value, 0)
+        # SureShrink adaptive wavelet denoising
+        DataFilter.perform_wavelet_denoising(session_data[eeg_chan], WaveletTypes.DB4, 5, threshold=ThresholdTypes.SOFT)
     return session_data
 
 def extract_features(preprocessed_data):
     features = []
     for eeg_row in preprocessed_data:
         # resample to match physionet dataset
-        feature = signal.resample(eeg_row, 160)
-        features.append(feature)
+        eeg_row = signal.resample(eeg_row, 160)
+        features.append(eeg_row)
     return np.stack(features, axis=-1)
 
 class Pipeline:
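On the `pipeline.py` side, `extract_features` resamples every channel to 160 samples and stacks on the last axis, which is what hands the encoder its channels-last layout regardless of the board's native rate. A sketch of that shape transformation, assuming a hypothetical 8-channel board at 250 Hz:

```python
import numpy as np
from scipy import signal

sampling_rate = 250                         # assumed board rate
window = np.random.randn(8, sampling_rate)  # one 1-second window, (channels, samples)

# Mirror of extract_features: resample each channel row to 160 samples
# (the PhysioNet rate), then stack channels on the last axis.
features = np.stack([signal.resample(row, 160) for row in window], axis=-1)

print(features.shape)  # (160, 8) -- timesteps x channels, ready for the encoder
```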
diff --git a/model/intent/train.py b/model/intent/train.py
index 1483330..4704fb4 100644
--- a/model/intent/train.py
+++ b/model/intent/train.py
@@ -9,12 +9,12 @@
 import keras
 from keras.models import Sequential
-from keras.optimizers import Adam
+from keras.optimizers import Adam, AdamW
 from keras.callbacks import EarlyStopping
 from keras.utils import to_categorical
 
 from sklearn.metrics import classification_report
 
-from model import create_first_layer, create_last_layer
+from model import StudentTeacherClassifier
 from pipeline import preprocess_data, extract_features
 
 SAVE_FILENAME = "recorded_eeg"
@@ -136,30 +136,24 @@ def process_windows(windows):
     X_test = np.concatenate([windows_test for _ , windows_test in processed_windows.values()])
     y_test = to_categorical(i_test, num_classes=len(processed_windows))
 
-    ## load pretrained encoder and keep it static
+    ## load pretrained encoder and freeze it for use in the perceptual loss
     pretrained_encoder = keras.models.load_model("physionet_encoder.keras")
     pretrained_encoder.trainable = False
+    ## load pretrained decoder and freeze it for use in the perceptual loss
+    pretrained_decoder = keras.models.load_model("physionet_decoder.keras")
+    pretrained_decoder.trainable = False
 
-    ## create channel expander/normalizer and classification layer
+    ## get class count from training data
     classes = len(processed_windows)
-    user_channels = len(eeg_channels)
-    encoder_channels = pretrained_encoder.input_shape[-1]
-    expandalizer = create_first_layer(user_channels, encoder_channels)
-    classifier = create_last_layer(classes)
 
     ## Create Model
-    model = Sequential([
-        expandalizer,
-        pretrained_encoder,
-        classifier
-    ])
+    model = StudentTeacherClassifier(pretrained_encoder, pretrained_decoder, classes)
 
     ## Compile the model
-    model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy')
+    model.compile(optimizer=AdamW(0.0001), loss=model.get_loss_function())
 
     ## Set up EarlyStopping
-    early_stopping = EarlyStopping(monitor='val_loss', patience=2*3, restore_best_weights=True, verbose=0)
+    early_stopping = EarlyStopping(monitor='val_loss', patience=2**3, restore_best_weights=True, verbose=0)
 
     ## Train the model
     batch_size = 128
@@ -173,6 +167,7 @@ def process_windows(windows):
     )
 
     ## Print out model summary
+    model = model.get_lean_model()
     model.summary()
 
     ## Save models for realtime use
@@ -182,8 +177,13 @@
     predictions_prob = model.predict(X_test)
     predictions = np.argmax(predictions_prob, axis=1)
     y_test_idxs = np.argmax(y_test, axis=1)
+    print("Model evaluation:")
+    model.evaluate(X_test, y_test)
     print(classification_report(y_test_idxs, predictions))
 
+    # Use the dark background style
+    plt.style.use('dark_background')
+
     ## Plot history accuracy from model
     plt.plot(fit_history.history['loss'])
     plt.plot(fit_history.history['val_loss'])
@@ -191,9 +191,44 @@
     plt.ylabel('loss')
     plt.xlabel('epoch')
     plt.legend(['train', 'val'], loc='upper left')
-    plt.show()
+    plt.ylim(0, 1)
+    plt.savefig('loss.png')
+
+    from sklearn.preprocessing import StandardScaler
+    from sklearn.manifold import TSNE
+    import tensorflow as tf
+
+    # `latent` from the student backbone has shape (samples, timesteps, features)
+    seq_model = Sequential(model.layers[:-1])
+    latent = seq_model(X_test)
+
+    # Step 1: Reshape to 2D by flattening the trailing dimensions
+    samples = latent.shape[0]  # Number of samples
+    # Flatten the timestep and feature dimensions into a single dimension
+    latent_flat = tf.reshape(latent, (samples, -1)).numpy()
+
+    # Step 2: Standardize the flattened data
+    scaler = StandardScaler()
+    X_scaled = scaler.fit_transform(latent_flat)
+
+    # Step 3: Apply t-SNE
+    tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, n_iter=1000)
+    X_tsne = tsne.fit_transform(X_scaled)
+
+    # Step 4: Plot the t-SNE result
+    plt.figure(figsize=(10, 10))
+    scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=i_test, cmap='viridis', alpha=0.7)
+    plt.colorbar(scatter, label='Labels')
+    plt.title('t-SNE Visualization of Labeled Data')
+    plt.xlabel('t-SNE Component 1')
+    plt.ylabel('t-SNE Component 2')
+    plt.grid(True)
+
+    # Set the scatter plot aspect to be square
+    plt.axis('square')
+    plt.savefig('tsne.png')
 
 if __name__ == "__main__":
     main()
-    
\ No newline at end of file
+
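The training objective in `StudentTeacherClassifier` is two-part: the compiled loss from `get_loss_function` (weighted categorical cross-entropy) plus a perceptual term registered with `add_loss`, which asks the student's features to survive a round trip through the frozen decoder-then-encoder teacher. A stand-alone sketch of that combined objective, with random tensors standing in for the student, teacher, and classifier outputs (shapes assume the `(batch, 20, 32)` latent):

```python
import tensorflow as tf

features = tf.random.normal((4, 20, 32))                     # student(inputs)
roundtrip = features + 0.05 * tf.random.normal((4, 20, 32))  # teacher(features) stand-in

y_true = tf.one_hot([0, 1, 2, 0], depth=3)                   # labels, 3 classes
y_pred = tf.nn.softmax(tf.random.normal((4, 3)))             # classifier(features) stand-in

perceptual_weight, classify_weight = 1.0, 1.0
perceptual = tf.keras.losses.MeanSquaredError()(features, roundtrip)  # added via add_loss
classify = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)  # via get_loss_function

total = classify_weight * classify + perceptual_weight * perceptual   # what the optimizer sees
print(float(total))
```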
diff --git a/model/intent/tsne.png b/model/intent/tsne.png
new file mode 100644
index 0000000..309825e
Binary files /dev/null and b/model/intent/tsne.png differ
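End to end, what ships to realtime use is the lean `Sequential` from `get_lean_model`. A usage sketch; the save path is not visible in these hunks, so `'intent_model.keras'` is a hypothetical filename, and the 8-channel / 250 Hz numbers are assumptions:

```python
import numpy as np
import keras

from pipeline import preprocess_data, extract_features  # helpers from this diff

model = keras.models.load_model('intent_model.keras')  # hypothetical save path

window = np.random.randn(8, 250)  # stand-in for one raw 1-second EEG window
features = extract_features(preprocess_data(window, 250))

probs = model.predict(features[np.newaxis, ...])  # add the batch dimension
print("predicted class:", int(np.argmax(probs)))
```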