Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix color masking in Pixie #1127

Merged
merged 11 commits into from
Mar 21, 2024
30 changes: 28 additions & 2 deletions src/ark/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,11 @@ def __init__(
].copy()

# Add a cluster_id_column to the column in case the cluster_column is
# non-numeric (i.e. string)
# non-numeric (i.e. string), index in ascending order of cell_meta_cluster
cluster_name_id = pd.DataFrame(
{self.cluster_column: mapping_data[self.cluster_column].unique()})
cluster_name_id.sort_values(by=f'{self.cluster_column}', inplace=True)
cluster_name_id.reset_index(drop=True, inplace=True)

cluster_name_id[self.cluster_id_column] = (cluster_name_id.index + 1).astype(np.int32)

Expand Down Expand Up @@ -537,9 +539,16 @@ def generate_pixel_cluster_mask(fov, base_dir, tiff_dir, chan_file_path,
# convert to 1D indexing
coordinates = x_coords * img_data.shape[1] + y_coords

# get the cooresponding cluster labels for each pixel
# get the corresponding cluster labels for each pixel
cluster_labels = list(fov_data[pixel_cluster_col])

# relabel clusters with sequential integers (cluster_id)
unique_clusters = list(np.unique(cluster_labels)) # returns sorted meta cluster numbers
cluster_ids = list(range(1, len(unique_clusters) + 1))
id_mapping = {meta_cluster: cluster_id
for meta_cluster, cluster_id in zip(unique_clusters, cluster_ids)}
cluster_labels = [id_mapping[label] for label in cluster_labels]

camisowers marked this conversation as resolved.
Show resolved Hide resolved
# assign each coordinate in pixel_cluster_mask to its respective cluster label
img_subset = img_data.ravel()
img_subset[coordinates] = cluster_labels
Expand All @@ -554,6 +563,7 @@ def generate_and_save_pixel_cluster_masks(fovs: List[str],
tiff_dir: Union[pathlib.Path, str],
chan_file: Union[pathlib.Path, str],
pixel_data_dir: Union[pathlib.Path, str],
cluster_id_to_name_path: Union[pathlib.Path, str],
pixel_cluster_col: str = 'pixel_meta_cluster',
sub_dir: str = None,
name_suffix: str = ''):
Expand All @@ -574,6 +584,9 @@ def generate_and_save_pixel_cluster_masks(fovs: List[str],
pixel_data_dir (Union[pathlib.Path, str]):
The path to the data with full pixel data.
This data should also have the SOM and meta cluster labels appended.
cluster_id_to_name_path (Union[str, pathlib.Path]): A path to a CSV identifying the
pixel cluster to manually-defined name mapping this is output by the remapping
visualization found in `metacluster_remap_gui`.
pixel_cluster_col (str, optional):
The path to the data with full pixel data.
This data should also have the SOM and meta cluster labels appended.
Expand All @@ -585,6 +598,19 @@ def generate_and_save_pixel_cluster_masks(fovs: List[str],
name_suffix (str, optional):
Specify what to append at the end of every pixel mask. Defaults to `''`.
"""
# read in gui cluster mapping file and save cluster_id created in generate_pixel_cluster_mask
gui_map = pd.read_csv(cluster_id_to_name_path)
cluster_map = gui_map.copy()[[pixel_cluster_col]]

cluster_map = cluster_map.drop_duplicates().sort_values(by=[pixel_cluster_col])
cluster_map["cluster_id"] = list(range(1, len(cluster_map) + 1))

# drop the cluster_id column from gui_map if it already exists, otherwise do nothing
gui_map = gui_map.drop(columns="cluster_id", errors="ignore")

# add a cluster_id column corresponding to the new mask integers
updated_cluster_map = gui_map.merge(cluster_map, on=[pixel_cluster_col], how="left")
updated_cluster_map.to_csv(cluster_id_to_name_path, index=False)

# create the pixel cluster masks across each fov
with tqdm(total=len(fovs), desc="Pixel Cluster Mask Generation", unit="FOVs") \
Expand Down
8 changes: 4 additions & 4 deletions src/ark/utils/metacluster_remap_gui/colormap_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def generate_meta_cluster_colormap_dict(meta_cluster_remap_path, cmap, cluster_t
remapping = pd.read_csv(meta_cluster_remap_path)

# assert the correct columns are contained
misc_utils.verify_same_elements(
remapping_cols=remapping.columns.values,
misc_utils.verify_in_list(
required_cols=[
f'{cluster_type}_som_cluster',
f'{cluster_type}_meta_cluster',
f'{cluster_type}_meta_cluster_rename'
]
f'{cluster_type}_meta_cluster_rename',
],
remapping_cols=remapping.columns.values
)

# define the raw meta cluster colormap
Expand Down
63 changes: 27 additions & 36 deletions src/ark/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class MetaclusterColormap:
background_color: Tuple[float, ...] = field(init=False)
metacluster_id_to_name: pd.DataFrame = field(init=False)
mc_colors: np.ndarray = field(init=False)
metacluster_to_index: Dict = field(init=False)
cmap: colors.ListedColormap = field(init=False)
norm: colors.BoundaryNorm = field(init=False)

Expand Down Expand Up @@ -99,6 +98,7 @@ def _metacluster_cmap_generator(self) -> None:
f"{self.cluster_type}_som_cluster",
f"{self.cluster_type}_meta_cluster",
f"{self.cluster_type}_meta_cluster_rename",
f"cluster_id",
camisowers marked this conversation as resolved.
Show resolved Hide resolved
],
cluster_mapping_cols=cluster_id_to_name.columns.values
)
Expand All @@ -112,7 +112,7 @@ def _metacluster_cmap_generator(self) -> None:
metacluster_id_to_name[f"{self.cluster_type}_meta_cluster"].max() + 1)

# Extract unique pairs of (metacluster-ID, name)
# Set the unassigned cluster ID to be the max ID + 1
# Set the unassigned meta cluster to be the max ID + 1
# Set 0 as the Empty value
metacluster_id_to_name: pd.DataFrame = pd.concat(
[
Expand All @@ -126,9 +126,6 @@ def _metacluster_cmap_generator(self) -> None:
]
)

# sort by metacluster id ascending, this will help when making the colormap
metacluster_id_to_name.sort_values(by=f'{self.cluster_type}_meta_cluster', inplace=True)

# add the unassigned color to the metacluster_colors dict
self.metacluster_colors.update({unassigned_id: self.unassigned_color})

Expand All @@ -147,14 +144,30 @@ def _metacluster_cmap_generator(self) -> None:
f"{self.cluster_type}_meta_cluster"
].map(self.metacluster_colors)

# Convert the list of tuples to a numpy array, each index is a color
# grab mask cluster_id integers and merge with raw_cmap
cluster_id_to_metacluster_map = cluster_id_to_name[
[f"{self.cluster_type}_meta_cluster", "cluster_id"]].drop_duplicates()
unassigned_cluster_id: int = int(cluster_id_to_name["cluster_id"].max() + 1)
cluster_id_to_metacluster_map = pd.concat(
[
cluster_id_to_metacluster_map,
pd.DataFrame(
data={
f"{self.cluster_type}_meta_cluster": [unassigned_cluster_id, 0],
"cluster_id": [unassigned_id, 0],
camisowers marked this conversation as resolved.
Show resolved Hide resolved
}
)
]
)
cluster_id_to_color = metacluster_id_to_name.merge(
cluster_id_to_metacluster_map, on=[f"{self.cluster_type}_meta_cluster"])

mc_colors: np.ndarray = np.array(metacluster_id_to_name['color'].to_list())
# sort by cluster_id ascending, so colors align with mask integers
cluster_id_to_color.sort_values(by="cluster_id", inplace=True)
cluster_id_to_color.reset_index(drop=True, inplace=True)

metacluster_to_index = {}
metacluster_id_to_name.reset_index(drop=True, inplace=True)
for index, row in metacluster_id_to_name.reset_index(drop=True).iterrows():
metacluster_to_index[row[f'{self.cluster_type}_meta_cluster']] = index
# Convert the list of tuples to a numpy array, each index is a color
mc_colors: np.ndarray = np.array(cluster_id_to_color['color'].to_list())
camisowers marked this conversation as resolved.
Show resolved Hide resolved

# generate the colormap
cmap = colors.ListedColormap(mc_colors)
Expand All @@ -166,26 +179,9 @@ def _metacluster_cmap_generator(self) -> None:
# Assign created values to dataclass attributes
self.metacluster_id_to_name = metacluster_id_to_name
self.mc_colors = mc_colors
self.metacluster_to_index = metacluster_to_index
self.cmap = cmap
self.norm = norm

def assign_metacluster_cmap(self, fov_img: np.ndarray) -> np.ndarray:
"""Assigns the metacluster colormap to the provided image.

Args:
fov_img (np.ndarray): The metacluster image to assign the colormap index to.

Returns:
np.ndarray: The image with the colormap index assigned.
"""
# explicitly relabel each value in fov_img with its index in mc_colors
# to ensure proper indexing into colormap
relabeled_fov = np.copy(fov_img)
for mc, mc_color_idx in self.metacluster_to_index.items():
relabeled_fov[fov_img == mc] = mc_color_idx
return relabeled_fov


def create_cmap(cmap: Union[np.ndarray, list[str], str],
n_clusters: int) -> tuple[colors.ListedColormap, colors.BoundaryNorm]:
Expand Down Expand Up @@ -453,10 +449,8 @@ def plot_pixel_cell_cluster(
if erode:
fov = erode_mask(seg_mask=fov, connectivity=2, mode="thick", background=0)

fov_img = mcc.assign_metacluster_cmap(fov_img=fov)

fig: Figure = plot_cluster(
image=fov_img,
image=fov,
fov=fov_name,
cmap=mcc.cmap,
norm=mcc.norm,
Expand Down Expand Up @@ -740,7 +734,7 @@ def create_mantis_dir(fovs: List[str], mantis_project_path: Union[str, pathlib.P
if not new_mask_suffix:
new_mask_suffix = mask_suffix

cluster_id_key = 'cluster_id' if cluster_type == 'cell' else 'pixel_meta_cluster'
cluster_id_key = 'cluster_id'
map_df = map_df.loc[:, [cluster_id_key, f'{cluster_type}_meta_cluster_rename']]
# remove duplicates from df
map_df = map_df.drop_duplicates()
Expand Down Expand Up @@ -895,11 +889,8 @@ def save_colored_masks(
xr_channel_names=None,
)

# The values in the colored_mask are the indices of the colors in mcc.mc_colors
# Make a new array with the actual colors, multiply by uint8 max to get 0-255 range

colored_mask: np.ndarray = (mcc.mc_colors[mcc.assign_metacluster_cmap(
mask.values.squeeze())] * 255.999).astype(np.uint8)
colored_mask: np.ndarray = (mcc.mc_colors[mask.squeeze()] * 255.999).astype(np.uint8)

image_utils.save_image(
fname=save_dir / f"{fov}_{cluster_type}_mask_colored.tiff",
Expand Down
3 changes: 2 additions & 1 deletion templates/2_Pixie_Cluster_Pixels.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@
" tiff_dir=tiff_dir,\n",
" chan_file=chan_file,\n",
" pixel_data_dir=pixel_data_dir,\n",
" cluster_id_to_name_path=os.path.join(base_dir, pixel_meta_cluster_remap_name),\n",
" pixel_cluster_col='pixel_meta_cluster',\n",
" sub_dir='pixel_masks',\n",
" name_suffix='_pixel_mask',\n",
Expand Down Expand Up @@ -927,7 +928,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.8"
},
"nbdime-conflicts": {
"local_diff": [
Expand Down
18 changes: 18 additions & 0 deletions tests/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,19 @@ def test_generate_and_save_pixel_cluster_masks(sub_dir, name_suffix):
consensus_data['row_index'] = np.random.randint(low=0, high=chan_dims[0], size=100)
consensus_data['column_index'] = np.random.randint(low=0, high=chan_dims[1], size=100)

# create pixel mapping file
cluster_id_to_name_path = os.path.join(temp_dir, 'mapping.csv')
df = pd.DataFrame.from_dict(
{
"pixel_som_cluster": np.arange(1, 11),
"pixel_meta_cluster": np.repeat(np.arange(5) + 1, 2),
"pixel_meta_cluster_rename": [
"meta" + str(i) for i in np.repeat(np.arange(5) + 1, 2)
]
}
)
df.to_csv(cluster_id_to_name_path, index=False)

feather.write_dataframe(
consensus_data, os.path.join(temp_dir, 'pixel_mat_consensus', fov + '.feather')
)
Expand All @@ -518,6 +531,7 @@ def test_generate_and_save_pixel_cluster_masks(sub_dir, name_suffix):
tiff_dir=temp_dir,
chan_file='chan0.tiff',
pixel_data_dir='pixel_mat_consensus',
cluster_id_to_name_path=cluster_id_to_name_path,
pixel_cluster_col='pixel_meta_cluster',
sub_dir=sub_dir,
name_suffix=name_suffix
Expand All @@ -535,6 +549,10 @@ def test_generate_and_save_pixel_cluster_masks(sub_dir, name_suffix):
assert pixel_mask.shape == actual_img_dims
assert np.all(pixel_mask <= 5)

# check that mapping was updated with cluster_id
mapping = pd.read_csv(cluster_id_to_name_path)
assert "cluster_id" in mapping.columns


@parametrize('sub_dir', [None, 'sub_dir'])
@parametrize('name_suffix', ['', 'sample_suffix'])
Expand Down
Loading
Loading