Skip to content

Commit

Permalink
Use the env class, add shortcuts for user-facing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
gauteh committed Sep 27, 2023
1 parent 5c66133 commit 385f14d
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 38 deletions.
31 changes: 17 additions & 14 deletions opendrift/models/basemodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import inspect
import logging

from opendrift.models.basemodel.environment import Environment
from opendrift.models.basemodel.environment import Environment, HasEnvironment

logging.captureWarnings(True)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,7 +55,7 @@
from opendrift.config import Configurable, CONFIG_LEVEL_ESSENTIAL, CONFIG_LEVEL_BASIC, CONFIG_LEVEL_ADVANCED


class OpenDriftSimulation(PhysicsMethods, Timeable, Configurable):
class OpenDriftSimulation(PhysicsMethods, Timeable, Configurable, HasEnvironment):
"""Generic trajectory model class, to be extended (subclassed).
This as an Abstract Base Class, meaning that only subclasses can
Expand Down Expand Up @@ -114,7 +114,6 @@ class OpenDriftSimulation(PhysicsMethods, Timeable, Configurable):
}

max_speed = 1 # Assumed max average speed of any element
required_profiles_z_range = None # [min_depth, max_depth]
plot_comparison_colors = [
'k', 'r', 'g', 'b', 'm', 'c', 'y', 'crimson', 'indigo', 'lightcoral',
'grey', 'sandybrown', 'palegreen', 'gold', 'yellowgreen', 'lime',
Expand Down Expand Up @@ -172,7 +171,7 @@ def __init__(self,
# List to store GeoJSON dicts of seeding commands
self.seed_geojson = []

self.env = Environment(self.required_variables, self._config)
self.env = Environment(self.required_variables, self.required_profiles_z_range, self.max_speed, self._config)

# Make copies of dictionaries so that they are private to each instance
self.status_categories = ['active'] # Particles are active by default
Expand Down Expand Up @@ -698,6 +697,10 @@ def ElementType(self):
def required_variables(self):
"""Any trajectory model implementation must list needed variables."""

@abstractproperty
def required_profiles_z_range(self):
"""Any trajectory model implementation must list range or return None."""

def test_data_folder(self):
import opendrift
return os.path.abspath(
Expand All @@ -709,8 +712,8 @@ def performance(self):

outStr = '--------------------\n'
outStr += 'Reader performance:\n'
for r in self.readers:
reader = self.readers[r]
for r in self.env.readers:
reader = self.env.readers[r]
if reader.is_lazy:
continue
outStr += '--------------------\n'
Expand Down Expand Up @@ -1811,7 +1814,7 @@ def run(self,
if steps is not None:
duration = steps * self.time_step
else:
for reader in self.readers.values():
for reader in self.env.readers.values():
if reader.end_time is not None:
if end_time is None:
end_time = reader.end_time
Expand Down Expand Up @@ -2101,11 +2104,11 @@ def run(self,
else:
readers = self.env.priority_list[var]
if readers[0].startswith(
'constant_reader') and var in self.readers[
'constant_reader') and var in self.env.readers[
readers[0]]._parameter_value_map:
self.add_metadata(
keyword,
self.readers[readers[0]]._parameter_value_map[var][0])
self.env.readers[readers[0]]._parameter_value_map[var][0])
else:
self.add_metadata(keyword, self.env.priority_list[var])

Expand Down Expand Up @@ -2427,7 +2430,7 @@ def set_up_map(self,
if 'land_binary_mask' in self.env.priority_list and self.env.priority_list[
'land_binary_mask'][0] == 'shape':
logger.debug('Using custom shapes for plotting land..')
ax.add_geometries(self.readers['shape'].polys,
ax.add_geometries(self.env.readers['shape'].polys,
ccrs.PlateCarree(globe=globe),
facecolor=land_color,
edgecolor='black')
Expand Down Expand Up @@ -3699,8 +3702,8 @@ def get_map_background(self, ax, background, crs, time=None):
variable = background[0] # A vector is requested
else:
variable = background # A scalar is requested
for readerName in self.readers:
reader = self.readers[readerName]
for readerName in self.env.readers:
reader = self.env.readers[readerName]
if variable in reader.variables:
if time is None or reader.start_time is None or (
time >= reader.start_time
Expand Down Expand Up @@ -4398,7 +4401,7 @@ def __repr__(self):
outStr += ' ' + variable + '\n'

lazy_readers = [
r for r in self.readers if self.readers[r].is_lazy is True
r for r in self.env.readers if self.env.readers[r].is_lazy is True
]
if len(lazy_readers) > 0:
outStr += '---\nLazy readers:\n'
Expand Down Expand Up @@ -4537,7 +4540,7 @@ def calculate_ftle(self,

if reader is None:
logger.info('No reader provided, using first available:')
reader = list(self.readers.items())[0][1]
reader = list(self.env.readers.items())[0][1]
logger.info(reader.name)
if isinstance(reader, pyproj.Proj):
proj = reader
Expand Down
24 changes: 17 additions & 7 deletions opendrift/models/basemodel/environment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import OrderedDict, Dict
from typing import OrderedDict, Dict, List
import copy
import traceback
import numpy as np
Expand All @@ -20,21 +20,22 @@ class Environment(Timeable, Configurable):
readers: OrderedDict
priority_list: OrderedDict
required_variables: Dict
required_profiles_z_range = None # [min_depth, max_depth]

max_speed = 1.0
required_profiles_z_range: List[float] # [min_depth, max_depth]
max_speed: float

proj_latlon = pyproj.Proj('+proj=latlong')

def __init__(self, required_variables, _config):
def __init__(self, required_variables, required_profiles_z_range, max_speed, _config):
super().__init__()

self.fallback_values = {}
self.readers = OrderedDict()
self.priority_list = OrderedDict()

self.required_variables = required_variables
self._config = _config
self.required_profiles_z_range = required_profiles_z_range
self.max_speed = max_speed
self._config = _config # reference to simulation config

# Find variables which require profiles
self.required_profiles = [
Expand All @@ -50,7 +51,7 @@ def __init__(self, required_variables, _config):
and self.required_variables[var]['important'] is False
]

def finalize(self, simulation: Configurable, simulation_extent):
def finalize(self, simulation: 'OpenDriftSimulation', simulation_extent):
"""
Prepare environment for simulation.
"""
Expand Down Expand Up @@ -821,3 +822,12 @@ def get_environment(self, variables, time, lon, lat, z, profiles):
self.timer_end('main loop:readers')

return env.view(np.recarray), env_profiles, missing

class HasEnvironment:
"""
A class that has an `Environment`. Some shortcuts for dealing with readers are provided to the inner `env` instance.
"""
env: Environment

def add_reader(self, readers, variables=None, first=False):
self.env.add_reader(readers, variables, first)
14 changes: 7 additions & 7 deletions opendrift/models/chemicaldrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,13 @@ def prepare_run(self):
logger.info('Transfer rates:\n %s' % self.transfer_rates)

self.SPM_vertical_levels_given = False
for key, value in self.readers.items():
for key, value in self.env.readers.items():
if 'spm' in value.variables:
if (hasattr(value,'sigma') or hasattr(value,'z') ):
self.SPM_vertical_levels_given = True

self.DOC_vertical_levels_given = False
for key, value in self.readers.items():
for key, value in self.env.readers.items():
if 'doc' in value.variables:
if (hasattr(value,'sigma') or hasattr(value,'z') ):
self.DOC_vertical_levels_given = True
Expand Down Expand Up @@ -1930,14 +1930,14 @@ def write_netcdf_chemical_density_map(self, filename, pixelsize_m='auto', zlevel
from netCDF4 import Dataset, date2num #, stringtochar

if landmask_shapefile is not None:
if 'shape' in self.readers.keys():
if 'shape' in self.env.readers.keys():
# removing previously stored landmask
del self.readers['shape']
del self.env.readers['shape']
# Adding new landmask
from opendrift.readers import reader_shape
custom_landmask = reader_shape.Reader.from_shpfiles(landmask_shapefile)
self.add_reader(custom_landmask)
elif 'global_landmask' not in self.readers.keys():
elif 'global_landmask' not in self.env.readers.keys():
from opendrift.readers import reader_global_landmask
global_landmask = reader_global_landmask.Reader()
self.add_reader(global_landmask)
Expand Down Expand Up @@ -2024,9 +2024,9 @@ def write_netcdf_chemical_density_map(self, filename, pixelsize_m='auto', zlevel

landmask = np.zeros_like(H[0,0,0,:,:])
if landmask_shapefile is not None:
landmask = self.readers['shape'].__on_land__(lon_array,lat_array)
landmask = self.env.readers['shape'].__on_land__(lon_array,lat_array)
else:
landmask = self.readers['global_landmask'].__on_land__(lon_array,lat_array)
landmask = self.env.readers['global_landmask'].__on_land__(lon_array,lat_array)
Landmask=np.zeros_like(H)
for zi in range(len(z_array)-1):
for sp in range(self.nspecies):
Expand Down
12 changes: 6 additions & 6 deletions opendrift/models/openberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class OpenBerg(OpenDriftSimulation):
'land_binary_mask': {'fallback': None},
}

required_profiles_z_range = [-120, 0] # [min_depth, max_depth]

# Default colors for plotting
status_colors = {'initial': 'green', 'active': 'blue',
'missing_data': 'gray', 'stranded': 'red'}
Expand All @@ -116,8 +118,6 @@ def __init__(self, d=None, label=None, *args, **kwargs):
#self.required_profiles = ['x_sea_water_velocity',
# 'y_sea_water_velocity'] # Get vertical current profiles

self.required_profiles_z_range = [-120, 0] # [min_depth, max_depth]

# Calling general constructor of parent class
super(OpenBerg, self).__init__(*args, **kwargs)

Expand Down Expand Up @@ -183,7 +183,7 @@ def prepare_run(self):
"""
# Retrieve profile provided in z dimension by reader
variable_groups, reader_groups, missing_variables = \
self.get_reader_groups(['x_sea_water_velocity','y_sea_water_velocity'])
self.env.get_reader_groups(['x_sea_water_velocity','y_sea_water_velocity'])

if len(reader_groups) == 0:
# No current data -> fallback values used
Expand All @@ -192,10 +192,10 @@ def prepare_run(self):

# Obtain depth levels from reader:
reader_name = reader_groups[0][0]
if hasattr(self.readers[reader_name], 'z'):
profile = np.abs(np.ma.filled(self.readers[reader_name].z))
if hasattr(self.env.readers[reader_name], 'z'):
profile = np.abs(np.ma.filled(self.env.readers[reader_name].z))
else: # ROMS sigma levels
profile = np.abs(np.ma.filled(self.readers[reader_name].zlevels))
profile = np.abs(np.ma.filled(self.env.readers[reader_name].zlevels))

# If current data is missing in at least one dimension, no weighting is performed:
if len(missing_variables) > 0:
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_chemicaldrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_chemicaldrift_partitioning_organics():

days=300
ntraj=200

# If one parameter is given as a list, different instances of
# ChemicalDrift will be created
configs = [
Expand All @@ -49,14 +49,14 @@ def test_chemicaldrift_partitioning_organics():
[ "chemical:particle_diameter_uncertainty" , 0 ],
[ "drift:vertical_mixing" , False ],
]

# possible to specify lists of different constant readers values
# that will be used in different longitude intervals
readers = [
[ "spm" , 80 ],
[ "x_sea_water_velocity" , 0 ],
[ "y_sea_water_velocity" , 0 ]]

# Build list of ChemicalDrift objecs with different parameters
o=list()
for i in range(len(configs)):
Expand Down Expand Up @@ -101,4 +101,4 @@ def test_chemicaldrift_partitioning_organics():

assert sum(o[0].elements.specie==2)/o[0].num_elements_total()*100 == 35
assert sum(o[1].elements.specie==2)/o[1].num_elements_total()*100 == 72.5
assert sum(o[2].elements.specie==2)/o[2].num_elements_total()*100 == 84
assert sum(o[2].elements.specie==2)/o[2].num_elements_total()*100 == 84

0 comments on commit 385f14d

Please sign in to comment.