Skip to content

Commit

Permalink
Added ScaleProtocol to make it possible to pass Scale objects that ar…
Browse files Browse the repository at this point in the history
…e not directly imported from dinosaur.scales.

Adjusted rtol in primitive equations evolved state integration test.

PiperOrigin-RevId: 700019135
  • Loading branch information
kochkov92 authored and Dinosaur authors committed Nov 25, 2024
1 parent 11981a1 commit 4fd3338
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
8 changes: 4 additions & 4 deletions dinosaur/primitive_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ class PrimitiveEquationsSpecs:
water_vapor_gas_constant: the non-dimensionalized gas constant for vapor.
water_vapor_isobaric_heat_capacity: isobaric heat capacity of vapor.
kappa: `ideal_gas_constant / Cp` where Cp is the isobaric heat capacity.
scale: an instance of `Scale` that will be used to (non-)dimensionalize
quantities.
scale: an instance implementing `ScaleProtocol` that will be used to
(non-)dimensionalize quantities.
"""

radius: float
Expand All @@ -287,7 +287,7 @@ class PrimitiveEquationsSpecs:
water_vapor_gas_constant: float
water_vapor_isobaric_heat_capacity: float
kappa: float
scale: scales.Scale
scale: scales.ScaleProtocol

@property
def R(self) -> float:
Expand Down Expand Up @@ -348,7 +348,7 @@ def from_si(
water_vapor_gas_constant_si: Quantity = scales.IDEAL_GAS_CONSTANT_H20,
water_vapor_isobaric_heat_capacity_si: Quantity = scales.WATER_VAPOR_CP,
kappa_si: Quantity = scales.KAPPA,
scale: scales.Scale = scales.DEFAULT_SCALE,
scale: scales.ScaleProtocol = scales.DEFAULT_SCALE,
) -> PrimitiveEquationsSpecs:
"""Constructs `PrimitiveEquantionSpecs` from SI constants."""
return cls(
Expand Down
5 changes: 2 additions & 3 deletions dinosaur/primitive_equations_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

"""Integration tests for primitive_equations."""
import dataclasses
import functools

from absl.testing import absltest
Expand All @@ -28,7 +27,7 @@
from dinosaur import time_integration
from dinosaur import xarray_utils
import jax
from jax import config
from jax import config # pylint: disable=g-importing-member
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -226,7 +225,7 @@ def test_real_vs_fast_spherical_harmonics(self):

with self.subTest('evolved state'):
assert_states_close(
real_out_nodal, fast_out_nodal, rtol=1e-5, range_tol=1e-5
real_out_nodal, fast_out_nodal, rtol=1e-4, range_tol=1e-5
)


Expand Down
14 changes: 13 additions & 1 deletion dinosaur/scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Code for describing units and non-dimensionalizing quantities."""

from collections import abc
from typing import Iterator
from typing import Iterator, Protocol

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -104,6 +104,18 @@ def _get_dimension(quantity: Quantity) -> str:
return str(quantity.dimensionality)


class ScaleProtocol(Protocol):
"""A protocol class for `Scale` objects that perform nondimensionalization."""

def nondimensionalize(self, quantity: Quantity) -> Numeric:
"""Converts a `pint.Quantity` to a non-dimensional value."""
...

def dimensionalize(self, value: Numeric, unit: Unit) -> Quantity:
"""Converts non-dimensional `value` to a `pint.Quantity` with `unit`."""
...


class Scale(abc.Mapping):
"""A `Scale` converts values to and from dimensionless quantities."""

Expand Down
8 changes: 4 additions & 4 deletions dinosaur/shallow_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ class ShallowWaterSpecs:
domain.
gravity_acceleration: the non-dimensionalized value of gravitational
acceleration.
scale: an instance of `Scale` that will be used to (non-)dimensionalize
quantities.
scale: an instance implementing `ScaleProtocol` that will be used to
(non-)dimensionalize quantities.
"""
densities: Array
radius: float
angular_velocity: float
gravity_acceleration: float
scale: scales.Scale
scale: scales.ScaleProtocol

@property
def g(self) -> float:
Expand All @@ -110,7 +110,7 @@ def from_si(
radius_si: Quantity = scales.RADIUS,
angular_velocity_si: Quantity = scales.ANGULAR_VELOCITY,
gravity_acceleration_si: Quantity = scales.GRAVITY_ACCELERATION,
scale: scales.Scale = scales.DEFAULT_SCALE) -> ShallowWaterSpecs:
scale: scales.ScaleProtocol = scales.DEFAULT_SCALE) -> ShallowWaterSpecs:
"""Constructs `ShallowWaterSpecs` from SI constants."""
return cls(scale.nondimensionalize(densities),
scale.nondimensionalize(radius_si),
Expand Down

0 comments on commit 4fd3338

Please sign in to comment.