Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ScaleProtocol to make it possible to pass Scale objects that are not directly imported from dinosaur.scales. #57

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions dinosaur/primitive_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ class PrimitiveEquationsSpecs:
water_vapor_gas_constant: the non-dimensionalized gas constant for vapor.
water_vapor_isobaric_heat_capacity: isobaric heat capacity of vapor.
kappa: `ideal_gas_constant / Cp` where Cp is the isobaric heat capacity.
scale: an instance of `Scale` that will be used to (non-)dimensionalize
quantities.
scale: an instance implementing `ScaleProtocol` that will be used to
(non-)dimensionalize quantities.
"""

radius: float
Expand All @@ -287,7 +287,7 @@ class PrimitiveEquationsSpecs:
water_vapor_gas_constant: float
water_vapor_isobaric_heat_capacity: float
kappa: float
scale: scales.Scale
scale: scales.ScaleProtocol

@property
def R(self) -> float:
Expand Down Expand Up @@ -348,7 +348,7 @@ def from_si(
water_vapor_gas_constant_si: Quantity = scales.IDEAL_GAS_CONSTANT_H20,
water_vapor_isobaric_heat_capacity_si: Quantity = scales.WATER_VAPOR_CP,
kappa_si: Quantity = scales.KAPPA,
scale: scales.Scale = scales.DEFAULT_SCALE,
scale: scales.ScaleProtocol = scales.DEFAULT_SCALE,
) -> PrimitiveEquationsSpecs:
"""Constructs `PrimitiveEquantionSpecs` from SI constants."""
return cls(
Expand Down
5 changes: 2 additions & 3 deletions dinosaur/primitive_equations_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

"""Integration tests for primitive_equations."""
import dataclasses
import functools

from absl.testing import absltest
Expand All @@ -28,7 +27,7 @@
from dinosaur import time_integration
from dinosaur import xarray_utils
import jax
from jax import config
from jax import config # pylint: disable=g-importing-member
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -226,7 +225,7 @@ def test_real_vs_fast_spherical_harmonics(self):

with self.subTest('evolved state'):
assert_states_close(
real_out_nodal, fast_out_nodal, rtol=1e-5, range_tol=1e-5
real_out_nodal, fast_out_nodal, rtol=1e-4, range_tol=1e-5
)


Expand Down
14 changes: 13 additions & 1 deletion dinosaur/scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Code for describing units and non-dimensionalizing quantities."""

from collections import abc
from typing import Iterator
from typing import Iterator, Protocol

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -104,6 +104,18 @@ def _get_dimension(quantity: Quantity) -> str:
return str(quantity.dimensionality)


class ScaleProtocol(Protocol):
"""A protocol class for `Scale` objects that perform nondimensionalization."""

def nondimensionalize(self, quantity: Quantity) -> Numeric:
"""Converts a `pint.Quantity` to a non-dimensional value."""
...

def dimensionalize(self, value: Numeric, unit: Unit) -> Quantity:
"""Converts non-dimensional `value` to a `pint.Quantity` with `unit`."""
...


class Scale(abc.Mapping):
"""A `Scale` converts values to and from dimensionless quantities."""

Expand Down
8 changes: 4 additions & 4 deletions dinosaur/shallow_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ class ShallowWaterSpecs:
domain.
gravity_acceleration: the non-dimensionalized value of gravitational
acceleration.
scale: an instance of `Scale` that will be used to (non-)dimensionalize
quantities.
scale: an instance implementing `ScaleProtocol` that will be used to
(non-)dimensionalize quantities.
"""
densities: Array
radius: float
angular_velocity: float
gravity_acceleration: float
scale: scales.Scale
scale: scales.ScaleProtocol

@property
def g(self) -> float:
Expand All @@ -110,7 +110,7 @@ def from_si(
radius_si: Quantity = scales.RADIUS,
angular_velocity_si: Quantity = scales.ANGULAR_VELOCITY,
gravity_acceleration_si: Quantity = scales.GRAVITY_ACCELERATION,
scale: scales.Scale = scales.DEFAULT_SCALE) -> ShallowWaterSpecs:
scale: scales.ScaleProtocol = scales.DEFAULT_SCALE) -> ShallowWaterSpecs:
"""Constructs `ShallowWaterSpecs` from SI constants."""
return cls(scale.nondimensionalize(densities),
scale.nondimensionalize(radius_si),
Expand Down
Loading