Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derived Harmonized Volumes in Age Trends module #167

Merged
merged 1 commit into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion NiBAx/core/model/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,15 @@ def __init__(self):
self.ADModel = None


def SetMUSEDictionaries(self, MUSEDictNAMEtoID, MUSEDictIDtoNAME):
def SetMUSEDictionaries(self, MUSEDictNAMEtoID, MUSEDictIDtoNAME,MUSEDictDataFrame):
"""Setter for MUSE dictionary"""
self.MUSEDictNAMEtoID = MUSEDictNAMEtoID
self.MUSEDictIDtoNAME = MUSEDictIDtoNAME
self.MUSEDictDataFrame = MUSEDictDataFrame

def SetDerivedMUSEMap(self,DerivedMUSEMap):
"""Setter for Derived MUSE dictionary"""
self.DerivedMUSEMap = DerivedMUSEMap


def SetDataFilePath(self,p):
Expand All @@ -65,6 +70,14 @@ def GetHarmonizationModelFilePath(self):
def GetMUSEDictionaries(self):
"""Get the MUSE dictionaries to map from ID to name and vice-versa"""
return self.MUSEDictNAMEtoID, self.MUSEDictIDtoNAME

def GetDerivedMUSEMap(self):
"""Get the derived MUSE dictionary to map from SINGLE to DERIVED ROIs"""
return self.DerivedMUSEMap

def GetMUSEDictDataFrame(self):
"""Get the MUSE dictionaries to map from ID to name and vice-versa"""
return self.MUSEDictDataFrame


def SetData(self,d):
Expand Down
82 changes: 65 additions & 17 deletions NiBAx/plugins/harmonization/harmonization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib.patches as pat
import numpy as np
import pandas as pd
import re
from NiBAx.core.plotcanvas import PlotCanvas
from NiBAx.core.baseplugin import BasePlugin
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox
Expand Down Expand Up @@ -39,6 +40,9 @@ def getUI(self):

def SetupConnections(self):
self.ui.load_harmonization_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
if self.datamodel.data is None:
self.ui.load_harmonization_model_Btn.setEnabled(False)
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
self.ui.show_data_Btn.clicked.connect(lambda: self.OnShowDataBtnClicked())
self.ui.apply_model_to_dataset_Btn.clicked.connect(lambda: self.OnApplyModelToDatasetBtnClicked())
Expand All @@ -54,12 +58,14 @@ def SetupConnections(self):
self.ui.show_data_Btn.setStyleSheet("background-color: rgb(230,230,255); color: black")
else:
self.ui.show_data_Btn.setEnabled(False)


def OnLoadHarmonizationModelBtnClicked(self):
filename, _ = QtWidgets.QFileDialog.getOpenFileName(None,
'Open harmonization model file',
QtCore.QDir().homePath(),
"Pickle files (*.pkl.gz *.pkl)")
self.filename = os.path.basename(filename)

if filename == "":
text_1=('Harmonization model not selected')
Expand Down Expand Up @@ -101,13 +107,9 @@ def OnLoadHarmonizationModelBtnClicked(self):
self.ui.stackedWidget.setCurrentIndex(0)

def PopulateROI(self):
#get data column header names
datakeys = self.datamodel.GetColumnHeaderNames()
#construct ROI list to populate comboBox
roiList = ( [x for x in datakeys if x.startswith('MUSE_Volume')])

MUSEDictDataFrame = self.datamodel.GetMUSEDictDataFrame()
_, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries()
roiList = list(set(roiList).intersection(set(datakeys)))
roiList = list(set(self.datamodel.GetColumnHeaderNames()).intersection(set(MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL']=='SINGLE']['ROI_COL'])))
roiList.sort()
roiList = ['(MUSE) ' + list(map(MUSEDictIDtoNAME.get, [k]))[0] if k.startswith('MUSE_') else k for k in roiList]

Expand All @@ -131,14 +133,18 @@ def UpdatePlot(self):
currentROI = self.ui.comboBoxROI.currentText()

# Translate ROI name back to ROI ID
try:
MUSEDictNAMEtoID, _ = self.datamodel.GetMUSEDictionaries()
if currentROI.startswith('(MUSE)'):
currentROI = list(map(MUSEDictNAMEtoID.get, [currentROI[7:]]))[0]
except:
currentROI = 'DLICV'
self.ui.comboBoxROI.setCurrentText('DLICV')
print("Could not translate combo box item. Setting to `DLICV`.")
AllItems = [self.ui.comboBoxROI.itemText(i) for i in range(self.ui.comboBoxROI.count())]
MUSEDictNAMEtoID, _ = self.datamodel.GetMUSEDictionaries()
if currentROI not in AllItems[:-1]:
self.ui.comboBoxROI.blockSignals(True)
self.ui.comboBoxROI.clear()
self.ui.comboBoxROI.blockSignals(False)
self.ui.comboBoxROI.addItems(AllItems[:-1])
currentROI = self.ui.comboBoxROI.itemText(0)
self.ui.comboBoxROI.setCurrentText(currentROI)
print("Invalid input. Setting to %s." % (currentROI))

currentROI = list(map(MUSEDictNAMEtoID.get, [currentROI[7:]]))[0]

#create empty dictionary of plot options
plotOptions = dict()
Expand Down Expand Up @@ -292,7 +298,12 @@ def plotMUSE(self,plotOptions):

def OnAddToDataFrame(self):
print('Saving modified data to pickle file...')
H_ROIs = ['H_'+x for x in self.datamodel.harmonization_model['ROIs']]
MUSEDictDataFrame= self.datamodel.GetMUSEDictDataFrame()
Derived_numbers = list(MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL']=='DERIVED']['ROI_INDEX'])
Derived_MUSE_Volumes = list('MUSE_Volume_' + str(x) for x in Derived_numbers)
ROI_list = list(self.datamodel.harmonization_model['ROIs']) + Derived_MUSE_Volumes
ROI_list.remove('MUSE_Volume_702')
H_ROIs = list('H_' + str(x) for x in ROI_list)
ROIs_ICV_Sex_Residuals = ['RES_ICV_Sex_' + x for x in self.datamodel.harmonization_model['ROIs']]
ROIs_Residuals = ['RES_' + x for x in self.datamodel.harmonization_model['ROIs']]
RAW_Residuals = ['RAW_RES_' + x for x in self.datamodel.harmonization_model['ROIs']]
Expand All @@ -316,6 +327,32 @@ def OnDataChanged(self):
else:
self.ui.show_data_Btn.setEnabled(False)

if self.datamodel.data is None:
self.ui.load_harmonization_model_Btn.setEnabled(False)
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
else:
self.ui.load_harmonization_model_Btn.setEnabled(True)
if self.datamodel.harmonization_model is None:
self.ui.Harmonized_Data_Information_Lbl.setText('No harmonization model has been selected')
else:
self.ui.Harmonized_Data_Information_Lbl.setObjectName('correct_label')
self.ui.Harmonized_Data_Information_Lbl.setStyleSheet('QLabel#correct_label {color: black}')
model_text1 = (self.filename +' loaded')
model_text2 = ('SITES in training set: '+ ' '.join([str(elem) for elem in list(self.datamodel.harmonization_model['SITE_labels'])]))
model_text2 = wrap_by_word(model_text2,4)
model_text1 += '\n\n'+model_text2
if 'Covariates' in self.datamodel.harmonization_model:
covariates = self.datamodel.harmonization_model['Covariates']
model_text3 = ('Harmonization Covariates: '+ str(covariates))
model_text1 += '\n'+model_text3
else:
model_text3 = ('Harmonization Covariates Unavailable')
model_text1 += '\n'+model_text3
age_max = self.datamodel.harmonization_model['smooth_model']['bsplines_constructor'].knot_kwds[0]['upper_bound']
age_min = self.datamodel.harmonization_model['smooth_model']['bsplines_constructor'].knot_kwds[0]['lower_bound']
model_text4 = ('Valid Age Range: [' + str(age_min) + ', ' + str(age_max) + ']')
model_text1 += '\n'+model_text4
self.ui.Harmonized_Data_Information_Lbl.setText(model_text1)

def DoHarmonization(self):
print('Running harmonization.')
Expand Down Expand Up @@ -354,9 +391,9 @@ def DoHarmonization(self):
continue

print('Harmonizing '+ site)
gamma_hat_site = np.mean(((Raw_ROIs_Residuals[new_site_is_train,:])/np.dot(np.sqrt(var_pooled),np.ones((1,Raw_ROIs_Residuals[new_site_is_train,:].shape[0]))).T),0)
gamma_hat_site = np.nanmean(((Raw_ROIs_Residuals[new_site_is_train,:])/np.dot(np.sqrt(var_pooled),np.ones((1,Raw_ROIs_Residuals[new_site_is_train,:].shape[0]))).T),0)
gamma_hat_site = gamma_hat_site[:,np.newaxis]
delta_hat_site = pow(np.std(((Raw_ROIs_Residuals[new_site_is_train,:])/np.dot(np.sqrt(var_pooled),np.ones((1,Raw_ROIs_Residuals[new_site_is_train,:].shape[0]))).T),0),2)
delta_hat_site = pow(np.nanstd(((Raw_ROIs_Residuals[new_site_is_train,:])/np.dot(np.sqrt(var_pooled),np.ones((1,Raw_ROIs_Residuals[new_site_is_train,:].shape[0]))).T),0),2)
delta_hat_site = delta_hat_site[:,np.newaxis]

site_gamma = pd.DataFrame(gamma_hat_site.T,columns=gamma_ROIs,index=[site])
Expand All @@ -378,8 +415,19 @@ def DoHarmonization(self):

if 'isTrainMUSEHarmonization' in self.datamodel.data.columns:
muse = pd.concat([self.datamodel.data['isTrainMUSEHarmonization'].copy(), covars, pd.DataFrame(bayes_data, columns=['H_' + s for s in self.datamodel.harmonization_model['ROIs']])],axis=1)
if 'UseForComBatGAMHarmonization' in self.datamodel.data.columns:
muse = pd.concat([self.datamodel.data['UseForComBatGAMHarmonization'].copy(), covars, pd.DataFrame(bayes_data, columns=['H_' + s for s in self.datamodel.harmonization_model['ROIs']])],axis=1)
else:
muse = pd.concat([covars,pd.DataFrame(bayes_data, columns=['H_' + s for s in self.datamodel.harmonization_model['ROIs']])],axis=1)

# harmonize derived volumes
MUSEDictDataFrame= self.datamodel.GetMUSEDictDataFrame()
muse_mappings = self.datamodel.GetDerivedMUSEMap()
for ROI in MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL']=='DERIVED']['ROI_INDEX']:
single_ROIs = muse_mappings.loc[ROI].replace('NaN',np.nan).dropna().astype(np.float)
single_ROIs = ['H_MUSE_Volume_%0d' % x for x in single_ROIs]
muse['H_MUSE_Volume_%d' % ROI] = muse[single_ROIs].sum(axis=1,skipna=False)
muse.drop(columns=['H_MUSE_Volume_702'], inplace=True)

start_index = len(self.datamodel.harmonization_model['SITE_labels'])
sex_icv_effect = np.dot(muse[['Sex','DLICV_baseline']].copy(), self.datamodel.harmonization_model['B_hat'][start_index:(start_index+2),:])
Expand Down
Loading