From a7a9a382f7b08a121926832404425bc64946d9b6 Mon Sep 17 00:00:00 2001 From: cgsandford <35029690+cgsandford@users.noreply.github.com> Date: Wed, 16 Dec 2020 15:01:39 +0000 Subject: [PATCH] Refactor CubeCombiner (#1383) * Factored multiplication case out of CubeCombiner. All CLI tests pass. Need to sort unit tests and then make code nicer. * Split existing unit tests between plugins and they pass. * Reduced "broadcast_to_coords" to "broadcast_to_threshold" boolean, consistent with use case and CLI interface. * Simplified setting up of coords for broadcast to use threshold only. * Factored out broadcasting into CubeMultiplier. Coord checking needs tidying up but all tests pass. * Removed unused method. * Pythonified coordinate checking. * Removed "self.operation" and summarised use more clearly. * Addressed comments from first review. * Response to second review. --- improver/cli/combine.py | 19 +- improver/cube_combiner.py | 282 ++++++++++-------- .../cube_combiner/test_CubeCombiner.py | 107 +------ .../cube_combiner/test_CubeMultiplier.py | 132 ++++++++ 4 files changed, 302 insertions(+), 238 deletions(-) create mode 100644 improver_tests/cube_combiner/test_CubeMultiplier.py diff --git a/improver/cli/combine.py b/improver/cli/combine.py index 87bfa57cf6..65c620b9a2 100755 --- a/improver/cli/combine.py +++ b/improver/cli/combine.py @@ -71,19 +71,22 @@ def process( result (iris.cube.Cube): Returns a cube with the combined data. """ - from improver.cube_combiner import CubeCombiner + from improver.cube_combiner import CubeCombiner, CubeMultiplier from iris.cube import CubeList if not cubes: raise TypeError("A cube is needed to be combined.") if new_name is None: new_name = cubes[0].name() - broadcast_to_coords = ["threshold"] if broadcast_to_threshold else None - result = CubeCombiner(operation, warnings_on=check_metadata)( - CubeList(cubes), - new_name, - broadcast_to_coords=broadcast_to_coords, - use_midpoint=use_midpoint, - ) + + if operation == "*" or operation == "multiply": + result = CubeMultiplier()( + CubeList(cubes), new_name, broadcast_to_threshold=broadcast_to_threshold, + ) + + else: + result = CubeCombiner(operation)( + CubeList(cubes), new_name, use_midpoint=use_midpoint, + ) return result diff --git a/improver/cube_combiner.py b/improver/cube_combiner.py index cafa6ee8f0..77db4e17d3 100644 --- a/improver/cube_combiner.py +++ b/improver/cube_combiner.py @@ -28,7 +28,9 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -"""Module containing plugin for CubeCombiner.""" +"""Module containing plugins for combining cubes""" + +from operator import eq import iris import numpy as np @@ -48,61 +50,48 @@ class CubeCombiner(BasePlugin): - - """Plugin for combining cubes. - - """ + """Plugin for combining cubes using linear operators""" COMBINE_OPERATORS = { "+": np.add, "add": np.add, "-": np.subtract, "subtract": np.subtract, - "*": np.multiply, - "multiply": np.multiply, "max": np.maximum, "min": np.minimum, "mean": np.add, } # mean is calculated in two steps: sum and normalise - def __init__(self, operation, warnings_on=False): - """ - Create a CubeCombiner plugin + def __init__(self, operation): + """Create a CubeCombiner plugin Args: operation (str): Operation (+, - etc) to apply to the incoming cubes. - warnings_on (bool): - If True output warnings for mismatching metadata. Raises: - ValueError: Unknown operation. - + ValueError: if operation is not recognised in dictionary """ try: self.operator = self.COMBINE_OPERATORS[operation] except KeyError: msg = "Unknown operation {}".format(operation) raise ValueError(msg) - self.operation = operation - self.broadcast_coords = None - self.warnings_on = warnings_on - - def __repr__(self): - """Represent the configured plugin instance as a string.""" - desc = "".format( - self.operation, self.warnings_on - ) - return desc - def _check_dimensions_match(self, cube_list): + self.normalise = operation == "mean" + + @staticmethod + def _check_dimensions_match(cube_list, comparators=[eq]): """ - Check all coordinate dimensions on the input cubes are equal or broadcastable + Check all coordinate dimensions on the input cubes match according to + the comparators specified. Args: - cube_list (iris.cube.CubeList or list): + cube_list (list of iris.cube.Cube): List of cubes to compare - + comparators (list of callable): + Comparison operators, at least one of which must return "True" + for each coordinate in order for the match to be valid Raises: ValueError: If dimension coordinates do not match """ @@ -110,7 +99,7 @@ def _check_dimensions_match(self, cube_list): for cube in cube_list[1:]: coords = cube.coords(dim_coords=True) compare = [ - (a == b) or self._coords_are_broadcastable(a, b) + np.any([comp(a, b) for comp in comparators]) for a, b in zip(coords, ref_coords) ] if not np.all(compare): @@ -120,19 +109,6 @@ def _check_dimensions_match(self, cube_list): ) raise ValueError(msg) - @staticmethod - def _coords_are_broadcastable(coord1, coord2): - """ - Broadcastable coords will differ only in length, so create a copy of one with - the points and bounds of the other and compare. Also ensure length of at least - one of the coords is 1. - """ - coord_copy = coord1.copy(coord2.points, bounds=coord2.bounds) - - return (coord_copy == coord2) and ( - (len(coord1.points) == 1) or (len(coord2.points) == 1) - ) - @staticmethod def _get_expanded_coord_names(cube_list): """ @@ -141,7 +117,7 @@ def _get_expanded_coord_names(cube_list): that are present on all input cubes, but have different values. Args: - cube_list (iris.cube.CubeList or list): + cube_list (list of iris.cube.Cube): List of cubes to that will be combined Returns: @@ -167,60 +143,32 @@ def _get_expanded_coord_names(cube_list): expanded_coords.append(coord) return expanded_coords - def _setup_coords_for_broadcast(self, cube_list): + def _combine_cube_data(self, cube_list): """ - Adds a scalar DimCoord to any subsequent cube in cube_list so that they all include all of - the coords specified in self.broadcast_coords in the right order. + Perform cumulative operation to combine cube data Args: - cube_list: (iris.cube.CubeList) + cube_list (list of iris.cube.Cube) Returns: - iris.cube.CubeList - Updated version of cube_list + iris.cube.Cube + Raises: + TypeError: if the operation results in an escalated datatype """ - for coord in self.broadcast_coords: - target_cube = cube_list[0] - try: - if coord == "threshold": - target_coord = find_threshold_coordinate(target_cube) - else: - target_coord = target_cube.coord(coord) - except CoordinateNotFoundError: - raise CoordinateNotFoundError( - f"Cannot find coord {coord} in {repr(target_cube)} to broadcast to." - ) - new_list = CubeList([]) - for cube in cube_list: - try: - found_coord = cube.coord(target_coord) - except CoordinateNotFoundError: - new_coord = target_coord.copy([0], bounds=None) - cube = cube.copy() - cube.add_aux_coord(new_coord, None) - cube = iris.util.new_axis(cube, new_coord) - enforce_coordinate_ordering( - cube, [d.name() for d in target_cube.coords(dim_coords=True)] - ) - else: - if found_coord not in cube.dim_coords: - # We don't expect the coord to already exist in a scalar form as - # this would indicate that the broadcast-from cube is only valid - # for part of the new dimension and therefore should be rejected. - raise TypeError( - f"Cannot broadcast to coord {coord} as it already exists as an AuxCoord" - ) - new_list.append(cube) - cube_list = new_list - return cube_list + result = cube_list[0].copy() + for cube in cube_list[1:]: + result.data = self.operator(result.data, cube.data) + + if self.normalise: + result.data = result.data / len(cube_list) + + enforce_dtype(str(self.operator), cube_list, result) + + return result def process( - self, - cube_list, - new_diagnostic_name, - broadcast_to_coords=None, - use_midpoint=False, + self, cube_list, new_diagnostic_name, use_midpoint=False, ): """ Combine data and metadata from a list of input cubes into a single @@ -228,26 +176,11 @@ def process( first cube in the input list provides the template for the combined cube metadata. - NOTE the behaviour for the "multiply" operation is different from - other types of cube combination. The only valid use case for - "multiply" is to apply a factor that conditions an input probability - field - that is, to apply Bayes Theorem. The input probability is - therefore used as the source of ALL input metadata, and should always - be the first cube in the input list. The factor(s) by which this is - multiplied are not compared for any mis-match in scalar coordinates, - neither do they to contribute to expanded bounds. - - TODO the "multiply" case should be factored out into a separate plugin - given its substantial differences from other combine use cases. - Args: - cube_list (iris.cube.CubeList or list): + cube_list (list of iris.cube.Cube): List of cubes to combine. new_diagnostic_name (str): New name for the combined diagnostic. - broadcast_to_coords (list): - Specifies a list of coord names that exist only on the first cube that - the other cube(s) need(s) broadcasting to prior to the combine. use_midpoint (bool): Determines the nature of the points and bounds for expanded coordinates. If False, the upper bound of the coordinate is @@ -259,39 +192,138 @@ def process( Raises: ValueError: If the cube_list contains only one cube. - TypeError: If combining data results in float64 data. """ if len(cube_list) < 2: msg = "Expecting 2 or more cubes in cube_list" raise ValueError(msg) - self.broadcast_coords = broadcast_to_coords - if self.broadcast_coords: - cube_list = self._setup_coords_for_broadcast(cube_list) self._check_dimensions_match(cube_list) + result = self._combine_cube_data(cube_list) + expanded_coord_names = self._get_expanded_coord_names(cube_list) + if expanded_coord_names: + result = expand_bounds( + result, cube_list, expanded_coord_names, use_midpoint=use_midpoint + ) + result.rename(new_diagnostic_name) + return result - # perform operation (add, subtract, min, max, multiply) cumulatively - result = cube_list[0].copy() - for cube in cube_list[1:]: - result.data = self.operator(result.data, cube.data) - # normalise mean (for which self.operator is np.add) - if self.operation == "mean": - result.data = result.data / len(cube_list) +class CubeMultiplier(CubeCombiner): + """Class to multiply input cubes + + The behaviour for the "multiply" operation is different from + other types of cube combination. The only valid use case for + "multiply" is to apply a factor that conditions an input probability + field - that is, to apply Bayes Theorem. The input probability is + therefore used as the source of ALL input metadata, and should always + be the first cube in the input list. The factor(s) by which this is + multiplied are not compared for any mis-match in scalar coordinates. + + """ + + def __init__(self): + """Create a CubeMultiplier plugin""" + self.operator = np.multiply + self.normalise = False - # Check resulting dtype - enforce_dtype(self.operation, cube_list, result) + def _setup_coords_for_broadcast(self, cube_list): + """ + Adds a scalar threshold to any subsequent cube in cube_list so that they all + match the dimensions, in order, of the first cube in the list + + Args: + cube_list (list of iris.cube.Cube) + + Returns: + iris.cube.CubeList + Updated version of cube_list - # where the operation is "multiply", retain all coordinate metadata - # from the first cube in the list; otherwise expand coordinate bounds - if self.operation != "multiply": - expanded_coord_names = self._get_expanded_coord_names(cube_list) - if expanded_coord_names: - result = expand_bounds( - result, cube_list, expanded_coord_names, use_midpoint=use_midpoint + Raises: + CoordinateNotFoundError: if there is no threshold coordinate on the + first cube in the list + TypeError: if there is a scalar threshold coordinate on any of the + later cubes, which would indicate that the cube is only valid for + a single threshold and should not be broadcast to all thresholds. + """ + target_cube = cube_list[0] + try: + target_coord = find_threshold_coordinate(target_cube) + except CoordinateNotFoundError: + raise CoordinateNotFoundError( + f"Cannot find coord threshold in {repr(target_cube)} to broadcast to" + ) + + new_list = CubeList([]) + for cube in cube_list: + try: + found_coord = cube.coord(target_coord) + except CoordinateNotFoundError: + new_coord = target_coord.copy([0], bounds=None) + cube = cube.copy() + cube.add_aux_coord(new_coord, None) + cube = iris.util.new_axis(cube, new_coord) + enforce_coordinate_ordering( + cube, [d.name() for d in target_cube.coords(dim_coords=True)] ) + else: + if found_coord not in cube.dim_coords: + msg = "Cannot broadcast to coord threshold as it already exists as an AuxCoord" + raise TypeError(msg) + new_list.append(cube) + + return new_list + + @staticmethod + def _coords_are_broadcastable(coord1, coord2): + """ + Broadcastable coords will differ only in length, so create a copy of one with + the points and bounds of the other and compare. Also ensure length of at least + one of the coords is 1. + """ + coord_copy = coord1.copy(coord2.points, bounds=coord2.bounds) + + return (coord_copy == coord2) and ( + (len(coord1.points) == 1) or (len(coord2.points) == 1) + ) + + def process( + self, cube_list, new_diagnostic_name, broadcast_to_threshold=False, + ): + """ + Multiply data from a list of input cubes into a single cube. The first + cube in the input list provides the combined cube metadata. + + Args: + cube_list (iris.cube.CubeList or list): + List of cubes to combine. + new_diagnostic_name (str): + New name for the combined diagnostic. + broadcast_to_threshold (bool): + True if the first cube has a threshold coordinate to which the + following cube(s) need(s) to be broadcast prior to combining data. + + Returns: + iris.cube.Cube: + Cube containing the combined data. + + Raises: + ValueError: If the cube_list contains only one cube. + TypeError: If combining data results in float64 data. + """ + if len(cube_list) < 2: + msg = "Expecting 2 or more cubes in cube_list" + raise ValueError(msg) + + if broadcast_to_threshold: + cube_list = self._setup_coords_for_broadcast(cube_list) + + self._check_dimensions_match( + cube_list, comparators=[eq, self._coords_are_broadcastable] + ) + + result = self._combine_cube_data(cube_list) - if self.broadcast_coords and "threshold" in self.broadcast_coords: + if broadcast_to_threshold: probabilistic_name = cube_list[0].name() diagnostic_name = extract_diagnostic_name(probabilistic_name) diff --git a/improver_tests/cube_combiner/test_CubeCombiner.py b/improver_tests/cube_combiner/test_CubeCombiner.py index 4034e7cda7..f40e869f78 100644 --- a/improver_tests/cube_combiner/test_CubeCombiner.py +++ b/improver_tests/cube_combiner/test_CubeCombiner.py @@ -55,7 +55,7 @@ class Test__init__(IrisTest): def test_basic(self): """Test that the __init__ sets things up correctly""" plugin = CubeCombiner("+") - self.assertEqual(plugin.operation, "+") + self.assertEqual(plugin.operator, np.add) def test_raise_error_wrong_operation(self): """Test __init__ raises a ValueError for invalid operation""" @@ -64,17 +64,6 @@ def test_raise_error_wrong_operation(self): CubeCombiner("%") -class Test__repr__(IrisTest): - - """Test the repr method.""" - - def test_basic(self): - """Test that the __repr__ returns the expected string.""" - result = str(CubeCombiner("+")) - msg = "" - self.assertEqual(result, msg) - - class CombinerTest(ImproverTest): """Set up a common set of test cubes for subsequent test classes.""" @@ -205,10 +194,7 @@ def test_mixed_dtypes_overflow(self): cubelist = iris.cube.CubeList( [self.cube1, self.cube2.copy(np.ones_like(self.cube2.data, dtype=np.int32))] ) - msg = ( - r"Operation add on types \{dtype\(\'.*\'\)\} results in " - r"float64 data which cannot be safely coerced to float32" - ) + msg = "Operation .* results in float64 data" with self.assertRaisesRegex(TypeError, msg): plugin.process(cubelist, "new_cube_name") @@ -285,95 +271,6 @@ def test_exception_for_single_entry_cubelist(self): with self.assertRaisesRegex(ValueError, msg): plugin.process(cubelist, "new_cube_name") - def test_broadcast_coord(self): - """Test that plugin broadcasts to a coord and doesn't change the inputs. - Using the broadcast_to_coords argument including a value of "threshold" - will result in the returned cube maintaining the probabilistic elements - of the name of the first input cube.""" - plugin = CubeCombiner("*") - cube = self.cube4[:, 0, ...].copy() - cube.data = np.ones_like(cube.data) - cube.remove_coord("lwe_thickness_of_precipitation_amount") - cubelist = iris.cube.CubeList([self.cube4.copy(), cube]) - input_copy = deepcopy(cubelist) - result = plugin.process( - cubelist, "new_cube_name", broadcast_to_coords=["threshold"] - ) - self.assertIsInstance(result, Cube) - self.assertEqual(result.name(), "probability_of_new_cube_name_above_threshold") - self.assertEqual(result.coord(var_name="threshold").name(), "new_cube_name") - self.assertArrayAlmostEqual(result.data, self.cube4.data) - self.assertCubeListEqual(input_copy, cubelist) - - def test_error_broadcast_coord_wrong_order(self): - """Test that plugin throws an error if the broadcast coord is not on the first cube""" - plugin = CubeCombiner("*") - cube = self.cube4[:, 0, ...].copy() - cube.data = np.ones_like(cube.data) - cube.remove_coord("lwe_thickness_of_precipitation_amount") - cubelist = iris.cube.CubeList([cube, self.cube4.copy()]) - msg = ( - "Cannot find coord threshold in " - " to broadcast to" - ) - with self.assertRaisesRegex(CoordinateNotFoundError, msg): - plugin.process(cubelist, "new_cube_name", broadcast_to_coords=["threshold"]) - - def test_error_broadcast_coord_not_found(self): - """Test that plugin throws an error if the broadcast coord is not present anywhere""" - plugin = CubeCombiner("*") - cube = self.cube4[:, 0, ...].copy() - cube.data = np.ones_like(cube.data) - cubelist = iris.cube.CubeList([self.cube4.copy(), cube]) - msg = ( - "Cannot find coord kittens in " - " " - "to broadcast to." - ) - with self.assertRaisesRegex(CoordinateNotFoundError, msg): - plugin.process(cubelist, "new_cube_name", broadcast_to_coords=["kittens"]) - - def test_error_broadcast_coord_is_auxcoord(self): - """Test that plugin throws an error if the broadcast coord already exists""" - plugin = CubeCombiner("*") - cube = self.cube4[:, 0, ...].copy() - cube.data = np.ones_like(cube.data) - cubelist = iris.cube.CubeList([self.cube4.copy(), cube]) - msg = "Cannot broadcast to coord threshold as it already exists as an AuxCoord" - with self.assertRaisesRegex(TypeError, msg): - plugin.process(cubelist, "new_cube_name", broadcast_to_coords=["threshold"]) - - def test_multiply_preserves_bounds(self): - """Test specific case for precipitation type, where multiplying a - precipitation accumulation by a point-time probability of snow retains - the bounds on the original accumulation.""" - validity_time = datetime(2015, 11, 19, 0) - time_bounds = [datetime(2015, 11, 18, 23), datetime(2015, 11, 19, 0)] - forecast_reference_time = datetime(2015, 11, 18, 22) - precip_accum = set_up_variable_cube( - np.full((2, 3, 3), 1.5, dtype=np.float32), - name="lwe_thickness_of_precipitation_amount", - units="mm", - time=validity_time, - time_bounds=time_bounds, - frt=forecast_reference_time, - ) - snow_prob = set_up_variable_cube( - np.full(precip_accum.shape, 0.2, dtype=np.float32), - name="probability_of_snow", - units="1", - time=validity_time, - frt=forecast_reference_time, - ) - plugin = CubeCombiner("multiply") - result = plugin.process( - [precip_accum, snow_prob], "lwe_thickness_of_snowfall_amount" - ) - self.assertArrayAlmostEqual(result.data, np.full((2, 3, 3), 0.3)) - self.assertArrayEqual(result.coord("time"), precip_accum.coord("time")) - if __name__ == "__main__": unittest.main() diff --git a/improver_tests/cube_combiner/test_CubeMultiplier.py b/improver_tests/cube_combiner/test_CubeMultiplier.py new file mode 100644 index 0000000000..854b800130 --- /dev/null +++ b/improver_tests/cube_combiner/test_CubeMultiplier.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# ----------------------------------------------------------------------------- +# (C) British Crown Copyright 2017-2020 Met Office. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +"""Unit tests for the cube_combiner.CubeMultiplier plugin.""" +import unittest +from copy import deepcopy +from datetime import datetime + +import iris +import numpy as np +from iris.cube import Cube +from iris.exceptions import CoordinateNotFoundError + +from improver.cube_combiner import CubeMultiplier +from improver.synthetic_data.set_up_test_cubes import set_up_variable_cube +from improver_tests.cube_combiner.test_CubeCombiner import CombinerTest + + +class Test_process(CombinerTest): + """Test process method of CubeMultiplier""" + + def test_broadcast_coord(self): + """Test that plugin broadcasts to threshold coord without changing inputs. + Using the broadcast_to_coords argument including a value of "threshold" + will result in the returned cube maintaining the probabilistic elements + of the name of the first input cube.""" + cube = self.cube4[:, 0, ...].copy() + cube.data = np.ones_like(cube.data) + cube.remove_coord("lwe_thickness_of_precipitation_amount") + cubelist = iris.cube.CubeList([self.cube4.copy(), cube]) + input_copy = deepcopy(cubelist) + result = CubeMultiplier()( + cubelist, "new_cube_name", broadcast_to_threshold=True + ) + self.assertIsInstance(result, Cube) + self.assertEqual(result.name(), "probability_of_new_cube_name_above_threshold") + self.assertEqual(result.coord(var_name="threshold").name(), "new_cube_name") + self.assertArrayAlmostEqual(result.data, self.cube4.data) + self.assertCubeListEqual(input_copy, cubelist) + + def test_error_broadcast_coord_not_found(self): + """Test that plugin throws an error if asked to broadcast to a threshold coord + that is not present on the first cube""" + cube = self.cube4[:, 0, ...].copy() + cube.data = np.ones_like(cube.data) + cube.remove_coord("lwe_thickness_of_precipitation_amount") + cubelist = iris.cube.CubeList([cube, self.cube4.copy()]) + msg = ( + "Cannot find coord threshold in " + " to broadcast to" + ) + with self.assertRaisesRegex(CoordinateNotFoundError, msg): + CubeMultiplier()(cubelist, "new_cube_name", broadcast_to_threshold=True) + + def test_error_broadcast_coord_is_auxcoord(self): + """Test that plugin throws an error if asked to broadcast to a threshold coord + that already exists on later cubes""" + cube = self.cube4[:, 0, ...].copy() + cube.data = np.ones_like(cube.data) + cubelist = iris.cube.CubeList([self.cube4.copy(), cube]) + msg = "Cannot broadcast to coord threshold as it already exists as an AuxCoord" + with self.assertRaisesRegex(TypeError, msg): + CubeMultiplier()(cubelist, "new_cube_name", broadcast_to_threshold=True) + + def test_multiply_preserves_bounds(self): + """Test specific case for precipitation type, where multiplying a + precipitation accumulation by a point-time probability of snow retains + the bounds on the original accumulation.""" + validity_time = datetime(2015, 11, 19, 0) + time_bounds = [datetime(2015, 11, 18, 23), datetime(2015, 11, 19, 0)] + forecast_reference_time = datetime(2015, 11, 18, 22) + precip_accum = set_up_variable_cube( + np.full((2, 3, 3), 1.5, dtype=np.float32), + name="lwe_thickness_of_precipitation_amount", + units="mm", + time=validity_time, + time_bounds=time_bounds, + frt=forecast_reference_time, + ) + snow_prob = set_up_variable_cube( + np.full(precip_accum.shape, 0.2, dtype=np.float32), + name="probability_of_snow", + units="1", + time=validity_time, + frt=forecast_reference_time, + ) + result = CubeMultiplier()( + [precip_accum, snow_prob], "lwe_thickness_of_snowfall_amount", + ) + self.assertArrayAlmostEqual(result.data, np.full((2, 3, 3), 0.3)) + self.assertArrayEqual(result.coord("time"), precip_accum.coord("time")) + + def test_exception_for_single_entry_cubelist(self): + """Test that the plugin raises an exception if a cubelist containing + only one cube is passed in.""" + plugin = CubeMultiplier() + msg = "Expecting 2 or more cubes in cube_list" + cubelist = iris.cube.CubeList([self.cube1]) + with self.assertRaisesRegex(ValueError, msg): + plugin.process(cubelist, "new_cube_name") + + +if __name__ == "__main__": + unittest.main()