diff --git a/lumicks/pylake/tests/test_imaging_confocal/conftest.py b/lumicks/pylake/tests/test_imaging_confocal/conftest.py index a8f3da212..1393f0205 100644 --- a/lumicks/pylake/tests/test_imaging_confocal/conftest.py +++ b/lumicks/pylake/tests/test_imaging_confocal/conftest.py @@ -1,13 +1,23 @@ +from itertools import permutations + import numpy as np import pytest +from lumicks.pylake.channel import Slice, Continuous from lumicks.pylake.point_scan import PointScan from ..data.mock_file import MockDataFile_v2 -from ..data.mock_confocal import MockConfocalFile, generate_scan_json, generate_kymo_with_ref +from ..data.mock_confocal import ( + MockConfocalFile, + generate_scan_json, + generate_kymo_with_ref, + generate_scan_with_ref, +) start = np.int64(20e9) dt = np.int64(62.5e6) +axes_map = {"X": 0, "Y": 1, "Z": 2} +channel_map = {"r": 0, "g": 1, "b": 2} @pytest.fixture(scope="module") @@ -228,3 +238,109 @@ def test_point_scan(): } return point_scan, reference + + +@pytest.fixture(scope="module") +def test_scans(): + image = np.random.poisson(5, size=(4, 5, 3)) + return { + (name := f"fast {axes[0]} slow {axes[1]}"): generate_scan_with_ref( + name, + image, + pixel_sizes_nm=[50, 50], + axes=[axes_map[k] for k in axes], + start=start, + dt=dt, + samples_per_pixel=5, + line_padding=50, + multi_color=True, + ) + for axes in permutations(axes_map.keys(), 2) + } + + +@pytest.fixture(scope="module") +def test_scans_multiframe(): + image = np.random.poisson(5, size=(2, 4, 5, 3)) + return { + (name := f"fast {axes[0]} slow {axes[1]} multiframe"): generate_scan_with_ref( + name, + image, + pixel_sizes_nm=[50, 50], + axes=[axes_map[k] for k in axes], + start=start, + dt=dt, + samples_per_pixel=5, + line_padding=50, + multi_color=True, + ) + for axes in permutations(axes_map.keys(), 2) + } + + +@pytest.fixture(scope="module") +def test_scan_missing_channels(): + empty = Slice(Continuous([], start=start, dt=dt)) + + def make_data(*missing_channels): + image = np.random.poisson(5, size=(4, 5, 3)) + for channel in missing_channels: + image[:, :, channel_map[channel[0]]] = 0 + + scan, ref = generate_scan_with_ref( + f"missing {', '.join(missing_channels)}", + image, + pixel_sizes_nm=[50, 50], + axes=[1, 0], + start=start, + dt=dt, + samples_per_pixel=5, + line_padding=50, + multi_color=True, + ) + + for channel in missing_channels: + setattr(scan.file, f"{channel}_photon_count", empty) + + return scan, ref + + return {key: make_data(*key) for key in [("red",), ("red", "blue"), ("red", "green", "blue")]} + + +@pytest.fixture(scope="module") +def test_scan_truncated(): + image = np.random.poisson(5, size=(2, 4, 5, 3)) + scan, ref = generate_scan_with_ref( + "truncated", + image, + pixel_sizes_nm=[50, 50], + axes=[1, 0], + start=start, + dt=dt, + samples_per_pixel=5, + line_padding=50, + multi_color=True, + ) + scan.start = start - dt + return scan, ref + + +@pytest.fixture(scope="module") +def test_scan_sted_bug(): + image = np.random.poisson(5, size=(2, 4, 5, 3)) + scan, ref = generate_scan_with_ref( + "sted bug", + image, + pixel_sizes_nm=[50, 50], + axes=[1, 0], + start=start, + dt=dt, + samples_per_pixel=5, + line_padding=50, + multi_color=True, + ) + corrected_start = scan.red_photon_count.timestamps[5] + + # start *between* samples + scan.start = corrected_start - np.int64(dt - 1e5) + return scan, ref, corrected_start diff --git a/lumicks/pylake/tests/test_imaging_confocal/test_scan.py b/lumicks/pylake/tests/test_imaging_confocal/test_scan.py new file mode 100644 index 000000000..baf9b0942 --- /dev/null +++ b/lumicks/pylake/tests/test_imaging_confocal/test_scan.py @@ -0,0 +1,75 @@ +import numpy as np +import pytest + + +def test_scan_attrs(test_scans, test_scans_multiframe): + for key, (scan, ref) in (test_scans | test_scans_multiframe).items(): + assert ( + repr(scan) + == f"Scan(pixels=({ref.metadata.pixels_per_line}, {ref.metadata.lines_per_frame}))" + ) + np.testing.assert_allclose(scan.timestamps, ref.timestamps.data) + assert scan.num_frames == ref.metadata.number_of_frames + np.testing.assert_equal(scan.pixel_time_seconds, ref.timestamps.pixel_time_seconds) + assert scan.pixels_per_line == ref.metadata.pixels_per_line + assert scan.lines_per_frame == ref.metadata.lines_per_frame + assert len(scan.infowave) == len(ref.infowave.data) + assert scan.get_image("rgb").shape == ref.image.shape + assert scan.get_image("red").shape == ref.image.shape[:-1] + assert scan.get_image("blue").shape == ref.image.shape[:-1] + assert scan.get_image("green").shape == ref.image.shape[:-1] + + assert scan.fast_axis == ref.metadata.fast_axis + np.testing.assert_allclose(scan.pixelsize_um, ref.metadata.pixelsize_um) + for key, value in ref.metadata.center_point_um.items(): + np.testing.assert_allclose(scan.center_point_um[key], value) + np.testing.assert_allclose( + scan.size_um, np.array(ref.metadata.num_pixels) * ref.metadata.pixelsize_um + ) + + np.testing.assert_equal( + scan.frame_timestamp_ranges(include_dead_time=True), + ref.timestamps.timestamp_ranges # For the single frame case, there is no dead time + if scan.num_frames == 1 + else ref.timestamps.timestamp_ranges_deadtime, + ) + np.testing.assert_equal(scan.frame_timestamp_ranges(), ref.timestamps.timestamp_ranges) + + +def test_missing_channels(test_scan_missing_channels): + channel_map = {"r": 0, "g": 1, "b": 2} + + for missing_channels, (scan, ref) in test_scan_missing_channels.items(): + rgb = scan.get_image("rgb") + assert rgb.shape == ref.image.shape + np.testing.assert_equal(scan.get_image("rgb"), ref.image) + + for channel in missing_channels: + assert not np.any(rgb[:, :, channel_map[channel[0]]]) + np.testing.assert_equal(scan.get_image(channel), np.zeros(ref.image.shape[:2])) + + +def test_damaged_scan(test_scan_truncated): + # Assume the user incorrectly exported only a partial scan + scan, ref = test_scan_truncated + with pytest.raises( + RuntimeError, + match=( + "Start of the scan was truncated. Reconstruction cannot proceed. " + "Did you export the entire scan time in Bluelake?" + ), + ): + scan.get_image("red").shape + + +def test_sted_bug(test_scan_sted_bug): + # Test for workaround for a bug in the STED delay mechanism which could result in scan start + # times ending up within the sample time. + scan, ref, corrected_start = test_scan_sted_bug + + # should not raise, but change the start appropriately to work around sted bug + # start is only adjusted only during image reconstruction + original_start = scan.start + scan.get_image("red").shape + assert scan.start != original_start + np.testing.assert_allclose(scan.start, corrected_start) diff --git a/lumicks/pylake/tests/test_imaging_confocal_old/test_scan.py b/lumicks/pylake/tests/test_imaging_confocal_old/test_scan.py index 9fa9449c0..5f3df47bd 100644 --- a/lumicks/pylake/tests/test_imaging_confocal_old/test_scan.py +++ b/lumicks/pylake/tests/test_imaging_confocal_old/test_scan.py @@ -10,126 +10,6 @@ from ..data.mock_confocal import generate_scan -def test_scan_attrs(test_scans): - scan = test_scans["fast Y slow X"] - assert repr(scan) == "Scan(pixels=(4, 5))" - - # fmt: off - reference_timestamps = np.array( - [ - [20062500000, 20812500000, 22187500000, 23562500000, 24937500000], - [20250000000, 21625000000, 22375000000, 23750000000, 25125000000], - [20437500000, 21812500000, 23187500000, 23937500000, 25312500000], - [20625000000, 22000000000, 23375000000, 24750000000, 25500000000], - ] - ).T - - np.testing.assert_allclose(scan.timestamps, np.transpose(reference_timestamps)) - assert scan.num_frames == 1 - assert scan.pixels_per_line == 4 - assert scan.lines_per_frame == 5 - assert len(scan.infowave) == 90 - assert scan.get_image("rgb").shape == (4, 5, 3) - assert scan.get_image("red").shape == (4, 5) - assert scan.get_image("blue").shape == (4, 5) - assert scan.get_image("green").shape == (4, 5) - - assert scan.fast_axis == "Y" - np.testing.assert_allclose(scan.pixelsize_um, [197 / 1000, 191 / 1000]) - np.testing.assert_allclose(scan.center_point_um["x"], 58.075877109272604) - np.testing.assert_allclose(scan.center_point_um["y"], 31.978375270573267) - np.testing.assert_allclose(scan.center_point_um["z"], 0) - np.testing.assert_allclose(scan.size_um, [0.197 * 5, 0.191 * 4]) - - scan = test_scans["fast Y slow X multiframe"] - reference_timestamps2 = np.zeros((2, 4, 3)) - reference_timestamps2[0, :, :] = reference_timestamps.T[:, :3] - reference_timestamps2[1, :, :2] = reference_timestamps.T[:, 3:] - - np.testing.assert_allclose(scan.timestamps, reference_timestamps2) - assert scan.num_frames == 2 - assert scan.pixels_per_line == 4 - assert scan.lines_per_frame == 3 - assert len(scan.infowave) == 90 - assert scan.get_image("rgb").shape == (2, 4, 3, 3) - assert scan.get_image("red").shape == (2, 4, 3) - assert scan.get_image("blue").shape == (2, 4, 3) - assert scan.get_image("green").shape == (2, 4, 3) - assert scan.fast_axis == "Y" - np.testing.assert_allclose(scan.pixelsize_um, [197 / 1000, 191 / 1000]) - np.testing.assert_allclose(scan.center_point_um["x"], 58.075877109272604) - np.testing.assert_allclose(scan.center_point_um["y"], 31.978375270573267) - np.testing.assert_allclose(scan.center_point_um["z"], 0) - np.testing.assert_allclose(scan.size_um, [0.197 * 3, 0.191 * 4]) - - scan = test_scans["fast X slow Z multiframe"] - reference_timestamps2 = np.zeros((2, 4, 3)) - reference_timestamps2[0, :, :] = reference_timestamps.T[:, :3] - reference_timestamps2[1, :, :2] = reference_timestamps.T[:, 3:] - reference_timestamps2 = reference_timestamps2.transpose([0, 2, 1]) - - np.testing.assert_allclose(scan.timestamps, reference_timestamps2) - assert scan.num_frames == 2 - assert scan.pixels_per_line == 4 - assert scan.lines_per_frame == 3 - assert len(scan.infowave) == 90 - assert scan.get_image("rgb").shape == (2, 3, 4, 3) - assert scan.get_image("red").shape == (2, 3, 4) - assert scan.get_image("blue").shape == (2, 3, 4) - assert scan.get_image("green").shape == (2, 3, 4) - assert scan.fast_axis == "X" - np.testing.assert_allclose(scan.pixelsize_um, [191 / 1000, 197 / 1000]) - np.testing.assert_allclose(scan.center_point_um["x"], 58.075877109272604) - np.testing.assert_allclose(scan.center_point_um["y"], 31.978375270573267) - np.testing.assert_allclose(scan.center_point_um["z"], 0) - np.testing.assert_allclose(scan.size_um, [0.191 * 4, 0.197 * 3]) - - scan = test_scans["fast Y slow Z multiframe"] - reference_timestamps2 = np.zeros((2, 4, 3)) - reference_timestamps2[0, :, :] = reference_timestamps.T[:, :3] - reference_timestamps2[1, :, :2] = reference_timestamps.T[:, 3:] - reference_timestamps2 = reference_timestamps2.transpose([0, 2, 1]) - - np.testing.assert_allclose(scan.timestamps, reference_timestamps2) - assert scan.num_frames == 2 - assert scan.pixels_per_line == 4 - assert scan.lines_per_frame == 3 - assert len(scan.infowave) == 90 - assert scan.get_image("rgb").shape == (2, 3, 4, 3) - assert scan.get_image("red").shape == (2, 3, 4) - assert scan.get_image("blue").shape == (2, 3, 4) - assert scan.get_image("green").shape == (2, 3, 4) - assert scan.fast_axis == "Y" - np.testing.assert_allclose(scan.pixelsize_um, [191 / 1000, 197 / 1000]) - np.testing.assert_allclose(scan.center_point_um["x"], 58.075877109272604) - np.testing.assert_allclose(scan.center_point_um["y"], 31.978375270573267) - np.testing.assert_allclose(scan.center_point_um["z"], 0) - np.testing.assert_allclose(scan.size_um, [0.191 * 4, 0.197 * 3]) - - scan = test_scans["red channel missing"] - rgb = scan.get_image("rgb") - assert rgb.shape == (4, 5, 3) - assert not np.any(rgb[:, :, 0]) - np.testing.assert_equal(scan.get_image("red"), np.zeros((4, 5))) - - assert scan.get_image("blue").shape == (4, 5) - assert scan.get_image("green").shape == (4, 5) - - scan = test_scans["rb channels missing"] - rgb = scan.get_image("rgb") - assert rgb.shape == (4, 5, 3) - assert not np.any(rgb[:, :, 0]) - assert not np.any(rgb[:, :, 2]) - np.testing.assert_equal(scan.get_image("red"), np.zeros((4, 5))) - np.testing.assert_equal(scan.get_image("blue"), np.zeros((4, 5))) - assert scan.get_image("green").shape == (4, 5) - - scan = test_scans["all channels missing"] - np.testing.assert_equal(scan.get_image("red"), np.zeros((4, 5))) - np.testing.assert_equal(scan.get_image("green"), np.zeros((4, 5))) - np.testing.assert_equal(scan.get_image("blue"), np.zeros((4, 5))) - - def test_slicing(test_scans): scan0 = test_scans["multiframe_poisson"] assert scan0.num_frames == 10 @@ -176,22 +56,6 @@ def compare_frames(original_frames, new_scan): assert scan0.num_frames == 10 -def test_damaged_scan(test_scans): - # Assume the user incorrectly exported only a partial scan (62500000 is the time step) - scan = test_scans["truncated_scan"] - with pytest.raises(RuntimeError): - scan.get_image("red").shape - - # Test for workaround for a bug in the STED delay mechanism which could result in scan start times ending up - # within the sample time. - scan = test_scans["sted bug"] - middle = test_scans["fast Y slow X"].red_photon_count.timestamps[5] - scan.get_image( - "red" - ).shape # should not raise, but change the start appropriately to work around sted bug - np.testing.assert_allclose(scan.start, middle) - - def test_plotting(test_scans): scan = test_scans["fast Y slow X multiframe"] scan.plot(channel="blue") @@ -304,85 +168,6 @@ def test_movie_export(tmpdir_factory, test_scans): scan.export_video("gray", "dummy.gif") # Gray is not a color! -@pytest.mark.parametrize( - "dim_x, dim_y, line_padding, start, dt, samples_per_pixel", - [ - (5, 6, 3, 14, 4, 4), - (3, 4, 60, 1592916040906356300, 12800, 30), - (3, 2, 60, 1592916040906356300, 12800, 3000), - ], -) -def test_single_frame_times(dim_x, dim_y, line_padding, start, dt, samples_per_pixel): - img = np.ones((dim_x, dim_y)) - scan = generate_scan( - "test", - img, - [1, 1], - start=start, - dt=dt, - samples_per_pixel=samples_per_pixel, - line_padding=line_padding, - ) - frame_times = scan.frame_timestamp_ranges() - assert len(frame_times) == 1 - assert frame_times[0][0] == start + line_padding * dt - line_time = dt * (img.shape[1] * samples_per_pixel + 2 * line_padding) * img.shape[0] - assert frame_times[0][1] == start + line_time - line_padding * dt - - # For the single frame case, there is no dead time, so these are identical - frame_times_inclusive = scan.frame_timestamp_ranges(include_dead_time=True) - assert len(frame_times_inclusive) == 1 - assert frame_times_inclusive[0][0] == frame_times[0][0] - assert frame_times_inclusive[0][1] == frame_times[0][1] - - -@pytest.mark.parametrize( - "dim_x, dim_y, frames, line_padding, start, dt, samples_per_pixel", - [ - (5, 6, 3, 3, 14, 4, 4), - (3, 4, 4, 60, 1592916040906356300, 12800, 30), - (3, 2, 3, 60, 1592916040906356300, 12800, 3000), - ], -) -def test_multiple_frame_times(dim_x, dim_y, frames, line_padding, start, dt, samples_per_pixel): - img = np.ones((frames, dim_x, dim_y)) - scan = generate_scan( - "test", - img, - [1, 1], - start=start, - dt=dt, - samples_per_pixel=samples_per_pixel, - line_padding=line_padding, - ) - frame_times = scan.frame_timestamp_ranges() - - line_time = dt * (img.shape[2] * samples_per_pixel + 2 * line_padding) * img.shape[1] - assert scan.num_frames == frames - assert len(frame_times) == scan.num_frames - assert frame_times[0][0] == start + line_padding * dt - assert frame_times[0][1] == start + line_time - line_padding * dt - assert frame_times[1][0] == start + line_padding * dt + line_time - assert frame_times[1][1] == start + 2 * line_time - line_padding * dt - assert frame_times[-1][0] == start + line_padding * dt + (len(frame_times) - 1) * line_time - assert frame_times[-1][1] == start + len(frame_times) * line_time - line_padding * dt - - def compare_inclusive(frame_times_inclusive): - # Start times should be the same - assert len(frame_times_inclusive) == scan.num_frames - assert frame_times_inclusive[0][0] == frame_times[0][0] - assert frame_times_inclusive[1][0] == frame_times[1][0] - assert frame_times_inclusive[-1][0] == frame_times[-1][0] - - assert frame_times_inclusive[0][1] == frame_times[1][0] - assert frame_times_inclusive[1][1] == frame_times[2][0] - assert frame_times_inclusive[-1][1] == frame_times[-1][0] + ( - frame_times[1][0] - frame_times[0][0] - ) - - compare_inclusive(scan.frame_timestamp_ranges(include_dead_time=True)) - - def test_scan_plot_rgb_absolute_color_adjustment(test_scans): """Tests whether we can set an absolute color range for an RGB plot.""" scan = test_scans["fast Y slow X"] @@ -465,20 +250,6 @@ def test_plot_single_channel_percentile_color_adjustment(test_scans): plt.close(fig) -@pytest.mark.parametrize( - "scan, pixel_time", - [ - ("fast Y slow X", 0.1875), - ("fast X slow Z multiframe", 0.1875), - ("fast Y slow X multiframe", 0.1875), - ("fast Y slow Z multiframe", 0.1875), - ("fast Y slow X", 0.1875), - ], -) -def test_scan_pixel_time(test_scans, scan, pixel_time): - np.testing.assert_allclose(test_scans[scan].pixel_time_seconds, pixel_time) - - @pytest.mark.parametrize( "x_min, x_max, y_min, y_max", [