Skip to content

Commit

Permalink
Add a hidden parameter to overwrite certain portions of the pipeline (#…
Browse files Browse the repository at this point in the history
…914)

* Add overwrite functionality for Pixie pixels and for cell SOM training

* Test the case where cell_som_cluster already exists

* Pin pyFlowSOM at 0.1.13 to ensure deterministic flag doesn't fail yet

* Fix a merge conflict

* Make sure we only drop cell_size if it exists

* Add overwrite functionality to averaging functions in cell and pixel clustering

* Force another push to rerun jobs

* Handle case where data and subset folders may not always match up pretty

* Reset cell index on FOVs subset

* Add test to handle case where FOV written to subset and not data
  • Loading branch information
alex-l-kong authored Mar 2, 2023
1 parent 97e6baa commit 0b093db
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 43 deletions.
52 changes: 40 additions & 12 deletions src/ark/phenotyping/cell_cluster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ def create_c2pc_data(fovs, pixel_data_path, cell_table_path,

def train_cell_som(fovs, base_dir, cell_table_path, cell_som_cluster_cols,
cell_som_input_data, som_weights_name='cell_som_weights.feather',
xdim=10, ydim=10, lr_start=0.05, lr_end=0.01, num_passes=1, seed=42):
xdim=10, ydim=10, lr_start=0.05, lr_end=0.01, num_passes=1, seed=42,
overwrite=False):
"""Run the SOM training on the expression columns specified in `cell_som_cluster_cols`.
Saves the SOM weights to `base_dir/som_weights_name`.
Expand Down Expand Up @@ -431,6 +432,8 @@ def train_cell_som(fovs, base_dir, cell_table_path, cell_som_cluster_cols,
The number of training passes to make through the dataset
seed (int):
The random seed to use for training the SOM
overwrite (bool):
If set, force retrains the SOM and overwrites the weights
Returns:
cluster_helpers.CellSOMCluster:
Expand Down Expand Up @@ -459,7 +462,7 @@ def train_cell_som(fovs, base_dir, cell_table_path, cell_som_cluster_cols,
# train the SOM weights
# NOTE: seed has to be set in cyFlowSOM.pyx, done by passing flag in PixieSOMCluster
print("Training SOM")
cell_pysom.train_som()
cell_pysom.train_som(overwrite=overwrite)

return cell_pysom

Expand Down Expand Up @@ -492,6 +495,11 @@ def cluster_cells(base_dir, cell_pysom, cell_som_cluster_cols):
cols_to_drop.append('cell_size')

# ensure the weights columns are valid indexes, do so by ensuring
# the cluster_counts_norm and weights columns are the same
# minus the metadata columns (and possibly cluster col) that appear in cluster_counts_norm
if 'cell_som_cluster' in cell_pysom.cell_data.columns.values:
cols_to_drop.append('cell_som_cluster')

# the cell_som_input_data and weights columns are the same
# minus the metadata columns that appear in cluster_counts_norm
cell_som_input_data = cell_pysom.cell_data.drop(
Expand All @@ -513,7 +521,7 @@ def cluster_cells(base_dir, cell_pysom, cell_som_cluster_cols):


def generate_som_avg_files(base_dir, cell_som_input_data, cell_som_cluster_cols,
cell_som_expr_col_avg_name):
cell_som_expr_col_avg_name, overwrite=False):
"""Computes and saves the average expression of all `cell_som_cluster_cols`
across cell SOM clusters.
Expand All @@ -527,6 +535,8 @@ def generate_som_avg_files(base_dir, cell_som_input_data, cell_som_cluster_cols,
cell_som_expr_col_avg_name (str):
The name of the file to write the average expression per column
across cell SOM clusters
overwrite (bool):
If set, regenerate the averages of `cell_som_cluster_cols` for SOM clusters
"""

# define the paths to the data
Expand All @@ -536,10 +546,15 @@ def generate_som_avg_files(base_dir, cell_som_input_data, cell_som_cluster_cols,
if 'cell_som_cluster' not in cell_som_input_data.columns.values:
raise ValueError('cell_som_input_data does not have SOM labels assigned')

# if the channel SOM average file already exists, skip
# if the channel SOM average file already exists and the overwrite flag isn't set, skip
if os.path.exists(som_expr_col_avg_path):
print("Already generated average expression file for each cell SOM column, skipping")
return
if not overwrite:
print("Already generated average expression file for each cell SOM column, skipping")
return

print(
"Overwrite flag set, regenerating average expression file for cell SOM clusters"
)

# compute the average column expression values per cell SOM cluster
print("Computing the average value of each training column specified per cell SOM cluster")
Expand Down Expand Up @@ -630,7 +645,7 @@ def cell_consensus_cluster(base_dir, cell_som_cluster_cols, cell_som_input_data,
def generate_meta_avg_files(base_dir, cell_cc, cell_som_cluster_cols,
cell_som_input_data,
cell_som_expr_col_avg_name,
cell_meta_expr_col_avg_name):
cell_meta_expr_col_avg_name, overwrite=False):
"""Computes and saves the average cluster column expression across cell meta clusters.
Assigns meta cluster labels to the data stored in `cell_som_expr_col_avg_name`.
Expand All @@ -649,6 +664,8 @@ def generate_meta_avg_files(base_dir, cell_cc, cell_som_cluster_cols,
Used to run consensus clustering on.
cell_meta_expr_col_avg_name (str):
Same as above except for cell meta clusters
overwrite (bool):
If set, regenerate the averages of `cell_som_cluster_cols` per meta cluster
"""
# define the paths to the data
som_expr_col_avg_path = os.path.join(base_dir, cell_som_expr_col_avg_name)
Expand All @@ -663,8 +680,13 @@ def generate_meta_avg_files(base_dir, cell_cc, cell_som_cluster_cols,

# if the column average file for cell meta clusters already exists
# and the overwrite flag isn't set, skip
if os.path.exists(meta_expr_col_avg_path):
print("Already generated average expression file for cell meta clusters, skipping")
return
if not overwrite:
print("Already generated average expression file for cell meta clusters, skipping")
return

print(
"Overwrite flag set, regenerating average expression file for cell meta clusters"
)

# compute the average value of each expression column per cell meta cluster
print("Computing the average value of each training column specified per cell meta cluster")
Expand Down Expand Up @@ -704,7 +726,8 @@ def generate_meta_avg_files(base_dir, cell_cc, cell_som_cluster_cols,
def generate_wc_avg_files(fovs, channels, base_dir, cell_cc, cell_som_input_data,
weighted_cell_channel_name='weighted_cell_channel.feather',
cell_som_cluster_channel_avg_name='cell_som_cluster_channel_avg.csv',
cell_meta_cluster_channel_avg_name='cell_meta_cluster_channel_avg.csv'):
cell_meta_cluster_channel_avg_name='cell_meta_cluster_channel_avg.csv',
overwrite=False):
"""Generate the weighted channel average files per cell SOM and meta clusters.
When running cell clustering with pixel clusters generated from Pixie, the counts of each
Expand Down Expand Up @@ -732,6 +755,8 @@ def generate_wc_avg_files(fovs, channels, base_dir, cell_cc, cell_som_input_data
per cell SOM cluster
cell_meta_cluster_channel_avg_name (str):
Same as above except for cell meta clusters
overwrite (bool):
If set, regenerate average weighted channel expression for SOM and meta clusters
"""
# define the paths to the data
weighted_channel_path = os.path.join(base_dir, weighted_cell_channel_name)
Expand All @@ -744,8 +769,11 @@ def generate_wc_avg_files(fovs, channels, base_dir, cell_cc, cell_som_input_data
# if the weighted channel average files exist and the overwrite flag isn't set, skip
if os.path.exists(som_cluster_channel_avg_path) and \
os.path.exists(meta_cluster_channel_avg_path):
print("Already generated average weighted channel expression files, skipping")
return
if not overwrite:
print("Already generated average weighted channel expression files, skipping")
return

print("Overwrite flag set, regenerating average weighted channel expression files")

print("Compute average weighted channel expression across cell SOM clusters")
cell_som_cluster_channel_avg = compute_cell_cluster_weighted_channel_avg(
Expand Down
30 changes: 22 additions & 8 deletions src/ark/phenotyping/cluster_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,17 @@ def normalize_data(self, external_data: pd.DataFrame) -> pd.DataFrame:

return external_data_norm

def train_som(self):
def train_som(self, overwrite=False):
"""Trains the SOM using `train_data`
"""
if self.weights is not None:
# do not train SOM if weights already exist and the same markers used to train
Args:
overwrite (bool):
If set, force retrains the SOM and overwrites the weights
"""
# if overwrite flag set, retrain SOM regardless of state
if overwrite:
warnings.warn('Overwrite flag set, retraining SOM')
# otherwise, do not train SOM if weights already exist and the same markers used to train
elif self.weights is not None:
if set(self.weights.columns.values) == set(self.columns):
warnings.warn('Pixel SOM already trained on specified markers')
return
Expand Down Expand Up @@ -275,7 +280,9 @@ def __init__(self, cell_data: pd.DataFrame, weights_path: pathlib.Path,
self.fovs = fovs

# subset cell_data on just the FOVs specified
self.cell_data = self.cell_data[self.cell_data['fov'].isin(self.fovs)]
self.cell_data = self.cell_data[
self.cell_data['fov'].isin(self.fovs)
].reset_index(drop=True)

# since cell_data is the only dataset, we can just normalize it immediately
self.normalize_data()
Expand All @@ -299,11 +306,18 @@ def normalize_data(self):
# assign back to cell_data
self.cell_data[self.columns] = cell_data_sub

def train_som(self):
def train_som(self, overwrite=False):
"""Trains the SOM using `cell_data`
Args:
overwrite (bool):
If set, force retrains the SOM and overwrites the weights
"""
if self.weights is not None:
# do not train SOM if weights already exist and the same columns used to train
# if overwrite flag set, retrain SOM regardless of state
if overwrite:
warnings.warn('Overwrite flag set, retraining SOM')

# otherwise, do not train SOM if weights already exist and the same columns used to train
elif self.weights is not None:
if set(self.weights.columns.values) == set(self.columns):
warnings.warn('Cell SOM already trained on specified columns')
return
Expand Down
75 changes: 56 additions & 19 deletions src/ark/phenotyping/pixel_cluster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,16 +691,20 @@ def create_pixel_matrix(fovs, channels, base_dir, tiff_dir, seg_dir,
# create variable for storing 99.9% values
quant_dat = pd.DataFrame()

# find all the FOV files in the subsetted directory
# find all the FOV files in the full data and subsetted directories
# NOTE: this handles the case where the data file was written, but not the subset file
fovs_sub = io_utils.list_files(os.path.join(base_dir, subset_dir), substrs='.feather')
fovs_data = io_utils.list_files(os.path.join(base_dir, data_dir), substrs='.feather')

# intersect the two fovs lists together (if a FOV appears in one but not the other, regenerate)
fovs_full = list(set(fovs_sub).intersection(fovs_data))

# trim the .feather suffix from the FOV file names present in both directories
fovs_sub = io_utils.remove_file_extensions(fovs_sub)
fovs_full = io_utils.remove_file_extensions(fovs_full)

# define the list of FOVs for preprocessing
# NOTE: if an existing FOV is already corrupted, future steps will discard it
fovs_list = list(set(fovs).difference(set(fovs_sub)))
fovs_list = list(set(fovs).difference(set(fovs_full)))

# if there are no FOVs left to preprocess don't run function
if len(fovs_list) == 0:
Expand Down Expand Up @@ -874,7 +878,8 @@ def train_pixel_som(fovs, channels, base_dir,
subset_dir='pixel_mat_subsetted',
norm_vals_name='post_rowsum_chan_norm.feather',
som_weights_name='pixel_som_weights.feather', xdim=10, ydim=10,
lr_start=0.05, lr_end=0.01, num_passes=1, seed=42):
lr_start=0.05, lr_end=0.01, num_passes=1, seed=42,
overwrite=False):
"""Run the SOM training on the subsetted pixel data.
Saves SOM weights to `base_dir/som_weights_name`.
Expand Down Expand Up @@ -904,6 +909,8 @@ def train_pixel_som(fovs, channels, base_dir,
The number of training passes to make through the dataset
seed (int):
The random seed to use for training the SOM
overwrite (bool):
If set, force retrains the SOM and overwrites the weights
Returns:
cluster_helpers.PixelSOMCluster:
Expand Down Expand Up @@ -939,7 +946,7 @@ def train_pixel_som(fovs, channels, base_dir,
# train the SOM weights
# NOTE: seed has to be set in cyFlowSOM.pyx, done by passing flag in PixieSOMCluster
print("Training SOM")
pixel_pysom.train_som()
pixel_pysom.train_som(overwrite=overwrite)

return pixel_pysom

Expand Down Expand Up @@ -981,7 +988,7 @@ def run_pixel_som_assignment(pixel_data_path, pixel_pysom_obj, fov):


def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_data',
multiprocess=False, batch_size=5):
multiprocess=False, batch_size=5, overwrite=False):
"""Uses trained SOM weights to assign cluster labels on full pixel data.
Saves data with cluster labels to `data_dir`.
Expand All @@ -1001,6 +1008,8 @@ def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_da
Whether to use multiprocessing or not
batch_size (int):
The number of FOVs to process in parallel, ignored if `multiprocess` is `False`
overwrite (bool):
If set, force overwrite the SOM labels in all the FOVs
"""

# define the paths to the data
Expand Down Expand Up @@ -1052,8 +1061,16 @@ def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_da
pixel_data_columns=sample_fov.columns.values
)

# only assign SOM clusters to FOVs that don't already have them
fovs_list = find_fovs_missing_col(base_dir, data_dir, 'pixel_som_cluster')
# if overwrite flag set, run on all FOVs in data_dir
if overwrite:
print('Overwrite flag set, reassigning SOM cluster labels to all FOVs')
os.mkdir(data_path + '_temp')
fovs_list = io_utils.remove_file_extensions(
io_utils.list_files(data_path, substrs='.feather')
)
# otherwise, only assign SOM clusters to FOVs that don't already have them
else:
fovs_list = find_fovs_missing_col(base_dir, data_dir, 'pixel_som_cluster')

# make sure fovs_list only contain fovs that exist in the master fovs list specified
fovs_list = list(set(fovs_list).intersection(fovs))
Expand Down Expand Up @@ -1116,7 +1133,7 @@ def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_da

def generate_som_avg_files(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_data_dir',
pc_chan_avg_som_cluster_name='pixel_channel_avg_som_cluster.csv',
num_fovs_subset=100, seed=42):
num_fovs_subset=100, seed=42, overwrite=False):
"""Computes and saves the average channel expression across pixel SOM clusters.
Args:
Expand All @@ -1136,6 +1153,8 @@ def generate_som_avg_files(fovs, channels, base_dir, pixel_pysom, data_dir='pixe
The number of FOVs to subset on for SOM cluster channel averaging
seed (int):
The random seed to set for subsetting FOVs
overwrite (bool):
If set, force overwrite the existing average channel expression file if it exists
"""

# define the paths to the data
Expand All @@ -1145,10 +1164,13 @@ def generate_som_avg_files(fovs, channels, base_dir, pixel_pysom, data_dir='pixe
if pixel_pysom.weights is None:
raise ValueError("Using untrained pixel_pysom object, please invoke train_som first")

# if the channel SOM average file already exists, skip
# if the channel SOM average file already exists and the overwrite flag isn't set, skip
if os.path.exists(som_cluster_avg_path):
print("Already generated SOM cluster channel average file, skipping")
return
if not overwrite:
print("Already generated SOM cluster channel average file, skipping")
return

print("Overwrite flag set, regenerating SOM cluster channel average file")

# compute average channel expression for each pixel SOM cluster
# and the number of pixels per SOM cluster
Expand Down Expand Up @@ -1211,7 +1233,7 @@ def run_pixel_consensus_assignment(pixel_data_path, pixel_cc_obj, fov):
def pixel_consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
data_dir='pixel_mat_data',
pc_chan_avg_som_cluster_name='pixel_channel_avg_som_cluster.csv',
multiprocess=False, batch_size=5, seed=42):
multiprocess=False, batch_size=5, seed=42, overwrite=False):
"""Run consensus clustering algorithm on pixel-level summed data across channels
Saves data with consensus cluster labels to `data_dir`.
Expand All @@ -1237,6 +1259,8 @@ def pixel_consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
The number of FOVs to process in parallel, ignored if `multiprocess` is `False`
seed (int):
The random seed to set for consensus clustering
overwrite (bool):
If set, force overwrites the meta labels in all the FOVs
Returns:
cluster_helpers.PixieConsensusCluster:
Expand All @@ -1250,8 +1274,16 @@ def pixel_consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
# path validation
io_utils.validate_paths([pixel_data_path, som_cluster_avg_path])

# only assign meta clusters to FOVs that don't already have them
fovs_list = find_fovs_missing_col(base_dir, data_dir, 'pixel_meta_cluster')
# if overwrite flag set, run on all FOVs in data_dir
if overwrite:
print('Overwrite flag set, reassigning meta cluster labels to all FOVs')
os.mkdir(pixel_data_path + '_temp')
fovs_list = io_utils.remove_file_extensions(
io_utils.list_files(pixel_data_path, substrs='.feather')
)
# otherwise, only assign meta clusters to FOVs that don't already have them
else:
fovs_list = find_fovs_missing_col(base_dir, data_dir, 'pixel_meta_cluster')

# make sure fovs_list only contain fovs that exist in the master fovs list specified
fovs_list = list(set(fovs_list).intersection(fovs))
Expand Down Expand Up @@ -1337,7 +1369,7 @@ def pixel_consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
def generate_meta_avg_files(fovs, channels, base_dir, pixel_cc, data_dir='pixel_mat_data',
pc_chan_avg_som_cluster_name='pixel_channel_avg_som_cluster.csv',
pc_chan_avg_meta_cluster_name='pixel_channel_avg_meta_cluster.csv',
num_fovs_subset=100, seed=42):
num_fovs_subset=100, seed=42, overwrite=False):
"""Computes and saves the average channel expression across pixel meta clusters.
Assigns meta cluster labels to the data stored in `pc_chan_avg_som_cluster_name`.
Expand All @@ -1361,6 +1393,8 @@ def generate_meta_avg_files(fovs, channels, base_dir, pixel_cc, data_dir='pixel_
The number of FOVs to subset on for meta cluster channel averaging
seed (int):
The random seed to use for subsetting FOVs
overwrite (bool):
If set, force overwrites the existing average channel expression file if it exists
"""

# define the paths to the data
Expand All @@ -1370,10 +1404,13 @@ def generate_meta_avg_files(fovs, channels, base_dir, pixel_cc, data_dir='pixel_
# path validation
io_utils.validate_paths([som_cluster_avg_path])

# if the channel meta average file already exists, skip
# if the channel meta average file already exists and the overwrite flag isn't set, skip
if os.path.exists(meta_cluster_avg_path):
print("Already generated meta cluster channel average file, skipping")
return
if not overwrite:
print("Already generated meta cluster channel average file, skipping")
return

print("Overwrite flag set, regenerating meta cluster channel average file")

# compute average channel expression for each pixel meta cluster
# and the number of pixels per meta cluster
Expand Down
Loading

0 comments on commit 0b093db

Please sign in to comment.