From 189583169842ee05967e958e059bf6f505431cfb Mon Sep 17 00:00:00 2001 From: Dmitrii Kochkov Date: Mon, 25 Nov 2024 14:41:21 -0800 Subject: [PATCH] Added ScaleProtocol to make it possible to pass Scale objects that are not directly imported from dinosaur.scales. Adjusted rtol in primitive equations evolved state integration test. PiperOrigin-RevId: 700106042 --- dinosaur/primitive_equations.py | 8 ++++---- dinosaur/primitive_equations_integration_test.py | 5 ++--- dinosaur/scales.py | 14 +++++++++++++- dinosaur/shallow_water.py | 8 ++++---- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/dinosaur/primitive_equations.py b/dinosaur/primitive_equations.py index bc0ba7c..8d5b070 100644 --- a/dinosaur/primitive_equations.py +++ b/dinosaur/primitive_equations.py @@ -276,8 +276,8 @@ class PrimitiveEquationsSpecs: water_vapor_gas_constant: the non-dimensionalized gas constant for vapor. water_vapor_isobaric_heat_capacity: isobaric heat capacity of vapor. kappa: `ideal_gas_constant / Cp` where Cp is the isobaric heat capacity. - scale: an instance of `Scale` that will be used to (non-)dimensionalize - quantities. + scale: an instance implementing `ScaleProtocol` that will be used to + (non-)dimensionalize quantities. """ radius: float @@ -287,7 +287,7 @@ class PrimitiveEquationsSpecs: water_vapor_gas_constant: float water_vapor_isobaric_heat_capacity: float kappa: float - scale: scales.Scale + scale: scales.ScaleProtocol @property def R(self) -> float: @@ -348,7 +348,7 @@ def from_si( water_vapor_gas_constant_si: Quantity = scales.IDEAL_GAS_CONSTANT_H20, water_vapor_isobaric_heat_capacity_si: Quantity = scales.WATER_VAPOR_CP, kappa_si: Quantity = scales.KAPPA, - scale: scales.Scale = scales.DEFAULT_SCALE, + scale: scales.ScaleProtocol = scales.DEFAULT_SCALE, ) -> PrimitiveEquationsSpecs: """Constructs `PrimitiveEquantionSpecs` from SI constants.""" return cls( diff --git a/dinosaur/primitive_equations_integration_test.py b/dinosaur/primitive_equations_integration_test.py index a7e32d2..ce35a92 100644 --- a/dinosaur/primitive_equations_integration_test.py +++ b/dinosaur/primitive_equations_integration_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Integration tests for primitive_equations.""" -import dataclasses import functools from absl.testing import absltest @@ -28,7 +27,7 @@ from dinosaur import time_integration from dinosaur import xarray_utils import jax -from jax import config +from jax import config # pylint: disable=g-importing-member import jax.numpy as jnp import numpy as np @@ -226,7 +225,7 @@ def test_real_vs_fast_spherical_harmonics(self): with self.subTest('evolved state'): assert_states_close( - real_out_nodal, fast_out_nodal, rtol=1e-5, range_tol=1e-5 + real_out_nodal, fast_out_nodal, rtol=1e-4, range_tol=1e-5 ) diff --git a/dinosaur/scales.py b/dinosaur/scales.py index fa6a89d..70ea0cb 100644 --- a/dinosaur/scales.py +++ b/dinosaur/scales.py @@ -15,7 +15,7 @@ """Code for describing units and non-dimensionalizing quantities.""" from collections import abc -from typing import Iterator +from typing import Iterator, Protocol import jax.numpy as jnp import numpy as np @@ -104,6 +104,18 @@ def _get_dimension(quantity: Quantity) -> str: return str(quantity.dimensionality) +class ScaleProtocol(Protocol): + """A protocol class for `Scale` objects that perform nondimensionalization.""" + + def nondimensionalize(self, quantity: Quantity) -> Numeric: + """Converts a `pint.Quantity` to a non-dimensional value.""" + ... + + def dimensionalize(self, value: Numeric, unit: Unit) -> Quantity: + """Converts non-dimensional `value` to a `pint.Quantity` with `unit`.""" + ... + + class Scale(abc.Mapping): """A `Scale` converts values to and from dimensionless quantities.""" diff --git a/dinosaur/shallow_water.py b/dinosaur/shallow_water.py index 3bf4d10..099fac2 100644 --- a/dinosaur/shallow_water.py +++ b/dinosaur/shallow_water.py @@ -76,14 +76,14 @@ class ShallowWaterSpecs: domain. gravity_acceleration: the non-dimensionalized value of gravitational acceleration. - scale: an instance of `Scale` that will be used to (non-)dimensionalize - quantities. + scale: an instance implementing `ScaleProtocol` that will be used to + (non-)dimensionalize quantities. """ densities: Array radius: float angular_velocity: float gravity_acceleration: float - scale: scales.Scale + scale: scales.ScaleProtocol @property def g(self) -> float: @@ -110,7 +110,7 @@ def from_si( radius_si: Quantity = scales.RADIUS, angular_velocity_si: Quantity = scales.ANGULAR_VELOCITY, gravity_acceleration_si: Quantity = scales.GRAVITY_ACCELERATION, - scale: scales.Scale = scales.DEFAULT_SCALE) -> ShallowWaterSpecs: + scale: scales.ScaleProtocol = scales.DEFAULT_SCALE) -> ShallowWaterSpecs: """Constructs `ShallowWaterSpecs` from SI constants.""" return cls(scale.nondimensionalize(densities), scale.nondimensionalize(radius_si),