Skip to content

Commit

Permalink
Added plotting of datasets in utils (#146)
Browse files Browse the repository at this point in the history
* Added plotting of datasets in utils and Added matplotlib as dependency for plotting echograms
  • Loading branch information
mihaiboldeanu authored Nov 29, 2023
1 parent b4bfc4c commit 5985cb6
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 2 deletions.
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies:
- pydantic>2
- echopype
- haversine
- matplotlib
- black
- bottleneck
- check-manifest
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- pydantic>2
- echopype
- haversine
- matplotlib
- pip
- git
- pip:
Expand Down
82 changes: 81 additions & 1 deletion oceanstream/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, List, Union
from pathlib import Path
from typing import Dict, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

Expand Down Expand Up @@ -95,3 +97,81 @@ def tfc(mask: xr.DataArray):
count_false = mask.size - count_true
true_false_counts = (count_true, count_false)
return true_false_counts


def plot_all_channels(
dataset1: xr.Dataset,
dataset2: Optional[xr.Dataset] = None,
variable_name: str = "Sv",
name: str = "",
save_path: Optional[Union[str, Path]] = "",
**kwargs,
):
"""
Plots echograms for all channels from one or two xarray Datasets.
This function iterates over channels in the specified variable of the given dataset(s) and creates echogram plots.
Each channel's data is plotted in a separate figure. When two datasets are provided, their respective echograms
for each channel are plotted side by side for comparison.
Parameters:
- dataset1 (xr.Dataset): The first xarray Dataset to plot.
- dataset2 (xr.Dataset, optional): The second xarray Dataset to plot alongside the first. Defaults to None.
- variable_name (str, optional): The name of the variable to plot from the dataset. Defaults to "Sv".
- name (str, optional): Base name for the output plot files. Defaults to empty string"".
- save_path ((str, Path) optional): Path where to save the images default is current working dir.
- **kwargs: Arbitrary keyword arguments. Commonly used for plot customization like `vmin`, `vmax`, and `cmap`.
Saves:
- PNG files for each channel's echogram, named using the variable name, the `name` parameter and channel name.
Example:
>> plot_all_channels(dataset1, dataset2, variable_name="Sv", name="echogram", vmin=-70, vmax=-30, cmap='inferno')
This will create and save echogram plots comparing dataset1 and dataset2 for each channel, using specified plot settings.
Note:
- If only one dataset is provided, echograms for that dataset alone will be plotted.
- The function handles plotting parameters such as color range (`vmin` and `vmax`) and colormap (`cmap`) via kwargs.
"""
for ch in dataset1[variable_name].channel.values:
plt.figure(figsize=(20, 10))

# Configure plotting parameters
plot_params = {
"vmin": kwargs.get("vmin", -100),
"vmax": kwargs.get("vmax", -40),
"cmap": kwargs.get("cmap", "viridis"),
}

if dataset2:
# First subplot for dataset1
ax1 = plt.subplot(1, 2, 1)
mappable1 = ax1.pcolormesh(
np.rot90(dataset1[variable_name].sel(channel=ch).values), **plot_params
)
plt.title(f"Original Data {ch}")

# Second subplot for dataset2
ax2 = plt.subplot(1, 2, 2)
ax2.pcolormesh(np.rot90(dataset2[variable_name].sel(channel=ch).values), **plot_params)
plt.title(f"Downsampled Data {ch}")

# Create a common colorbar
plt.colorbar(mappable1, ax=[ax1, ax2], orientation="vertical")

else:
ax = plt.subplot(1, 1, 1)
plt.pcolormesh(np.rot90(dataset1[variable_name].sel(channel=ch).values), **plot_params)
plt.title(f"{variable_name} Data {ch}")

# Create a colorbar
plt.colorbar(ax=ax, orientation="vertical")

# Save the figure
if save_path:
used_path = Path(save_path)
used_path = used_path / f"{name}_{variable_name}_channel_{ch}.png"
else:
used_path = f"{name}_{variable_name}_channel_{ch}.png"
plt.savefig(used_path)
plt.close()
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ geopy
pydantic>2
echopype
dask-image
matplotlib

# Development dependencies
black
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ pydantic>2
echopype
haversine
dask-image
matplotlib
13 changes: 12 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os

from oceanstream.L2_calibrated_data.noise_masks import create_seabed_mask
from oceanstream.L2_calibrated_data.sv_computation import compute_sv
from oceanstream.utils import *


def test_dict_to_formatted_list():
# Define a sample dictionary
sample_dict = {
Expand Down Expand Up @@ -80,3 +81,13 @@ def test_tfc():
)
res = tfc(data)
assert res == (2, 1)


def test_plotting(ed_ek_60_for_Sv):
current_directory = os.path.dirname(os.path.abspath(__file__))
TEST_DATA_FOLDER = os.path.join(current_directory, "..", "test_data")
source_Sv = compute_sv(ed_ek_60_for_Sv)
plot_all_channels(source_Sv,name="test_image", save_path=TEST_DATA_FOLDER)

plot_all_channels(source_Sv,source_Sv, name="test_image_double", save_path=TEST_DATA_FOLDER)

0 comments on commit 5985cb6

Please sign in to comment.