Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure cluster_helpers.PixelSOMCluster receives the fovs argument #905

Merged
merged 9 commits into from
Feb 14, 2023
2 changes: 1 addition & 1 deletion ark/phenotyping/cell_cluster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def train_cell_som(fovs, channels, base_dir, pixel_data_dir, cell_table_path,

# define the cell SOM cluster object
cell_pysom = cluster_helpers.CellSOMCluster(
cluster_counts_size_norm_path, som_weights_path, cluster_count_cols,
cluster_counts_size_norm_path, som_weights_path, fovs, cluster_count_cols,
num_passes=num_passes, xdim=xdim, ydim=ydim, lr_start=lr_start, lr_end=lr_end
)

Expand Down
6 changes: 3 additions & 3 deletions ark/phenotyping/cell_cluster_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def test_cluster_cells(pixel_cluster_prefix):
# error test: no weights assigned to cell pysom object
with pytest.raises(ValueError):
cell_pysom_bad = cluster_helpers.CellSOMCluster(
cluster_counts_size_norm_path, 'bad_path.feather', cluster_cols
cluster_counts_size_norm_path, 'bad_path.feather', [-1], cluster_cols
)

cell_cluster_utils.cluster_cells(base_dir=temp_dir, cell_pysom=cell_pysom_bad)
Expand All @@ -648,7 +648,7 @@ def test_cluster_cells(pixel_cluster_prefix):
feather.write_dataframe(weights, weights_path)

cell_pysom_bad = cluster_helpers.CellSOMCluster(
cluster_counts_size_norm_path, weights_path, cluster_cols
cluster_counts_size_norm_path, weights_path, [-1], cluster_cols
)

cell_cluster_utils.cluster_cells(base_dir=temp_dir, cell_pysom=cell_pysom_bad)
Expand All @@ -662,7 +662,7 @@ def test_cluster_cells(pixel_cluster_prefix):

# define a CellSOMCluster object
cell_pysom = cluster_helpers.CellSOMCluster(
cluster_counts_size_norm_path, cell_som_weights_path, cluster_cols
cluster_counts_size_norm_path, cell_som_weights_path, [-1], cluster_cols
)

# error test: bad cluster_col provided
Expand Down
22 changes: 18 additions & 4 deletions ark/phenotyping/cluster_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def generate_som_clusters(self, external_data: pd.DataFrame) -> np.ndarray:

class PixelSOMCluster(PixieSOMCluster):
def __init__(self, pixel_subset_folder: pathlib.Path, norm_vals_path: pathlib.Path,
weights_path: pathlib.Path, columns: List[str],
weights_path: pathlib.Path, fovs: List[str], columns: List[str],
num_passes: int = 1, xdim: int = 10, ydim: int = 10,
lr_start: float = 0.05, lr_end: float = 0.01):
"""Creates a pixel SOM cluster object derived from the abstract PixieSOMCluster
Expand All @@ -128,6 +128,8 @@ def __init__(self, pixel_subset_folder: pathlib.Path, norm_vals_path: pathlib.Pa
The name of the feather file containing the normalization values.
weights_path (pathlib.Path):
The path to save the weights to.
fovs (List[str]):
The list of FOVs to subset the data on.
columns (List[str]):
The list of columns to subset the data on.
num_passes (int):
Expand All @@ -151,10 +153,14 @@ def __init__(self, pixel_subset_folder: pathlib.Path, norm_vals_path: pathlib.Pa
# load the normalization values in
self.norm_data = feather.read_dataframe(norm_vals_path)

# define the fovs used
self.fovs = fovs

# list all the files in pixel_subset_folder and load them to train_data
fov_files = list_files(pixel_subset_folder, substrs='.feather')
self.train_data = pd.concat(
[feather.read_dataframe(os.path.join(pixel_subset_folder, fov)) for fov in fov_files]
[feather.read_dataframe(os.path.join(pixel_subset_folder, fov)) for fov in fov_files
if os.path.splitext(fov)[0] in fovs]
)

# we can just normalize train_data now since that's what we'll be training on
Expand Down Expand Up @@ -225,15 +231,17 @@ def assign_som_clusters(self, external_data: pd.DataFrame) -> pd.DataFrame:

class CellSOMCluster(PixieSOMCluster):
def __init__(self, cell_data_path: pathlib.Path, weights_path: pathlib.Path,
columns: List[str], num_passes: int = 1, xdim: int = 10, ydim: int = 10,
lr_start: float = 0.05, lr_end: float = 0.01):
fovs: List[str], columns: List[str], num_passes: int = 1,
xdim: int = 10, ydim: int = 10, lr_start: float = 0.05, lr_end: float = 0.01):
"""Creates a cell SOM cluster object derived from the abstract PixieSOMCluster

Args:
cell_data_path (pathlib.Path):
The name of the cell dataset to use for training
weights_path (pathlib.Path):
The path to save the weights to.
fovs (List[str]):
The list of FOVs to subset the data on.
columns (List[str]):
The list of columns to subset the data on.
num_passes (int):
Expand All @@ -258,6 +266,12 @@ def __init__(self, cell_data_path: pathlib.Path, weights_path: pathlib.Path,
# load the cell data in
self.cell_data = feather.read_dataframe(cell_data_path)

# define the fovs used
self.fovs = fovs

# subset cell_data on just the FOVs specified
self.cell_data = self.cell_data[self.cell_data['fov'].isin(self.fovs)]

# since cell_data is the only dataset, we can just normalize it immediately
self.normalize_data()

Expand Down
8 changes: 4 additions & 4 deletions ark/phenotyping/cluster_helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def pixel_pyflowsom_object(pixel_som_base_dir) -> Iterator[
# define a PixelSOMCluster object with weights
pixel_som_with_weights = PixelSOMCluster(
pixel_subset_folder=pixel_sub_path, norm_vals_path=norm_vals_path,
weights_path=weights_path, columns=channels, xdim=20, ydim=10
weights_path=weights_path, fovs=fovs, columns=channels, xdim=20, ydim=10
)

# define a PixelSOMCluster object without weights
pixel_som_sans_weights = PixelSOMCluster(
pixel_subset_folder=pixel_sub_path, norm_vals_path=norm_vals_path,
weights_path=pixel_som_base_dir / 'weights_new.feather',
columns=channels, xdim=20, ydim=10
fovs=fovs[:2], columns=channels, xdim=20, ydim=10
)

yield pixel_som_with_weights, pixel_som_sans_weights
Expand Down Expand Up @@ -154,13 +154,13 @@ def cell_pyflowsom_object(cell_som_base_dir) -> Iterator[
# define a CellSOMCluster object with weights
cell_som_with_weights = CellSOMCluster(
cell_data_path=cell_data_path, weights_path=weights_path,
columns=count_cols, xdim=20, ydim=10
fovs=['fov0', 'fov1'], columns=count_cols, xdim=20, ydim=10
)

# define a CellSOMCluster object without weights
cell_som_sans_weights = CellSOMCluster(
cell_data_path=cell_data_path, weights_path=cell_som_base_dir / 'weights_new.feather',
columns=count_cols, xdim=20, ydim=10
fovs=['fov0'], columns=count_cols, xdim=20, ydim=10
)

yield cell_som_with_weights, cell_som_sans_weights
Expand Down
8 changes: 7 additions & 1 deletion ark/phenotyping/pixel_cluster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,7 @@ def train_pixel_som(fovs, channels, base_dir,

# define the pixel SOM cluster object
pixel_pysom = cluster_helpers.PixelSOMCluster(
subsetted_path, norm_vals_path, som_weights_path, channels,
subsetted_path, norm_vals_path, som_weights_path, fovs, channels,
num_passes=num_passes, xdim=xdim, ydim=ydim, lr_start=lr_start, lr_end=lr_end
)

Expand Down Expand Up @@ -1052,6 +1052,9 @@ def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_da
# only assign SOM clusters to FOVs that don't already have them
fovs_list = find_fovs_missing_col(base_dir, data_dir, 'pixel_som_cluster')

# make sure fovs_list only contain fovs that exist in the master fovs list specified
fovs_list = list(set(fovs_list).intersection(fovs))

# if there are no FOVs left without SOM labels don't run function
if len(fovs_list) == 0:
print("There are no more FOVs to assign SOM labels to, skipping")
Expand Down Expand Up @@ -1247,6 +1250,9 @@ def pixel_consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
# only assign meta clusters to FOVs that don't already have them
fovs_list = find_fovs_missing_col(base_dir, data_dir, 'pixel_meta_cluster')

# make sure fovs_list only contain fovs that exist in the master fovs list specified
fovs_list = list(set(fovs_list).intersection(fovs))

# if there are no FOVs left without meta labels don't run function
if len(fovs_list) == 0:
print("There are no more FOVs to assign meta labels to, skipping")
Expand Down
13 changes: 7 additions & 6 deletions ark/phenotyping/pixel_cluster_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ def test_run_pixel_som_assignment():
# define example PixelSOMCluster object
sample_pixel_cc = cluster_helpers.PixelSOMCluster(
os.path.join(temp_dir, 'pixel_mat_data'), sample_norm_vals_path,
sample_som_weights_path, chans
sample_som_weights_path, fovs, chans
)

fov_status = pixel_cluster_utils.run_pixel_som_assignment(
Expand Down Expand Up @@ -1468,13 +1468,14 @@ def test_cluster_pixels_base(multiprocess):
with pytest.raises(ValueError):
pixel_pysom_bad = cluster_helpers.PixelSOMCluster(
os.path.join(temp_dir, 'pixel_mat_data'), norm_vals_path,
'bad_path.feather', chan_list
'bad_path.feather', fovs, chan_list
)
pixel_cluster_utils.cluster_pixels(fovs, chan_list, temp_dir, pixel_pysom_bad)

# create a sample PixelSOMCluster object
pixel_pysom = cluster_helpers.PixelSOMCluster(
os.path.join(temp_dir, 'pixel_mat_data'), norm_vals_path, som_weights_path, chan_list
os.path.join(temp_dir, 'pixel_mat_data'), norm_vals_path, som_weights_path,
fovs, chan_list
)

# run SOM cluster assignment
Expand Down Expand Up @@ -1505,7 +1506,7 @@ def test_cluster_pixels_corrupt(multiprocess, capsys):

# create a sample PixelSOMCluster object
pixel_pysom = cluster_helpers.PixelSOMCluster(
os.path.join(temp_dir, 'pixel_mat_data'), norm_vals_path, som_weights_path, chans
os.path.join(temp_dir, 'pixel_mat_data'), norm_vals_path, som_weights_path, fovs, chans
)

# corrupt a fov for this test
Expand Down Expand Up @@ -1566,13 +1567,13 @@ def test_generate_som_avg_files(capsys):
# error test: weights not assigned to PixelSOMCluster object
with pytest.raises(ValueError):
pixel_pysom_bad = cluster_helpers.PixelSOMCluster(
pixel_data_path, norm_vals_path, 'bad_path.feather', chan_list
pixel_data_path, norm_vals_path, 'bad_path.feather', fovs, chan_list
)
pixel_cluster_utils.generate_som_avg_files(fovs, chan_list, temp_dir, pixel_pysom_bad)

# define an example PixelSOMCluster object
pixel_pysom = cluster_helpers.PixelSOMCluster(
pixel_data_path, norm_vals_path, weights_path, chan_list
pixel_data_path, norm_vals_path, weights_path, fovs, chan_list
)

# test base generation with all subsetted FOVs
Expand Down
2 changes: 1 addition & 1 deletion ark/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def plot_pixel_cell_cluster_overlay(img_xr, fovs, cluster_id_to_name_path, metac
plt.axis('off')

# remove the gridlines
plt.grid(b=None)
plt.grid(visible=False)

# define the colorbar with annotations
cax = fig.add_axes([0.9, 0.1, 0.01, 0.8])
Expand Down