From 385f14db9d4f11a1e0ba6e9fe3a37e4f7cbe5574 Mon Sep 17 00:00:00 2001 From: Gaute Hope Date: Wed, 27 Sep 2023 12:10:40 +0200 Subject: [PATCH] Use the env class, add shortcuts for user-facing methods --- opendrift/models/basemodel/__init__.py | 31 +++++++++++++---------- opendrift/models/basemodel/environment.py | 24 +++++++++++++----- opendrift/models/chemicaldrift.py | 14 +++++----- opendrift/models/openberg.py | 12 ++++----- tests/models/test_chemicaldrift.py | 8 +++--- 5 files changed, 51 insertions(+), 38 deletions(-) diff --git a/opendrift/models/basemodel/__init__.py b/opendrift/models/basemodel/__init__.py index 246801afe..17b80dd24 100644 --- a/opendrift/models/basemodel/__init__.py +++ b/opendrift/models/basemodel/__init__.py @@ -22,7 +22,7 @@ import inspect import logging -from opendrift.models.basemodel.environment import Environment +from opendrift.models.basemodel.environment import Environment, HasEnvironment logging.captureWarnings(True) logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ from opendrift.config import Configurable, CONFIG_LEVEL_ESSENTIAL, CONFIG_LEVEL_BASIC, CONFIG_LEVEL_ADVANCED -class OpenDriftSimulation(PhysicsMethods, Timeable, Configurable): +class OpenDriftSimulation(PhysicsMethods, Timeable, Configurable, HasEnvironment): """Generic trajectory model class, to be extended (subclassed). This as an Abstract Base Class, meaning that only subclasses can @@ -114,7 +114,6 @@ class OpenDriftSimulation(PhysicsMethods, Timeable, Configurable): } max_speed = 1 # Assumed max average speed of any element - required_profiles_z_range = None # [min_depth, max_depth] plot_comparison_colors = [ 'k', 'r', 'g', 'b', 'm', 'c', 'y', 'crimson', 'indigo', 'lightcoral', 'grey', 'sandybrown', 'palegreen', 'gold', 'yellowgreen', 'lime', @@ -172,7 +171,7 @@ def __init__(self, # List to store GeoJSON dicts of seeding commands self.seed_geojson = [] - self.env = Environment(self.required_variables, self._config) + self.env = Environment(self.required_variables, self.required_profiles_z_range, self.max_speed, self._config) # Make copies of dictionaries so that they are private to each instance self.status_categories = ['active'] # Particles are active by default @@ -698,6 +697,10 @@ def ElementType(self): def required_variables(self): """Any trajectory model implementation must list needed variables.""" + @abstractproperty + def required_profiles_z_range(self): + """Any trajectory model implementation must list range or return None.""" + def test_data_folder(self): import opendrift return os.path.abspath( @@ -709,8 +712,8 @@ def performance(self): outStr = '--------------------\n' outStr += 'Reader performance:\n' - for r in self.readers: - reader = self.readers[r] + for r in self.env.readers: + reader = self.env.readers[r] if reader.is_lazy: continue outStr += '--------------------\n' @@ -1811,7 +1814,7 @@ def run(self, if steps is not None: duration = steps * self.time_step else: - for reader in self.readers.values(): + for reader in self.env.readers.values(): if reader.end_time is not None: if end_time is None: end_time = reader.end_time @@ -2101,11 +2104,11 @@ def run(self, else: readers = self.env.priority_list[var] if readers[0].startswith( - 'constant_reader') and var in self.readers[ + 'constant_reader') and var in self.env.readers[ readers[0]]._parameter_value_map: self.add_metadata( keyword, - self.readers[readers[0]]._parameter_value_map[var][0]) + self.env.readers[readers[0]]._parameter_value_map[var][0]) else: self.add_metadata(keyword, self.env.priority_list[var]) @@ -2427,7 +2430,7 @@ def set_up_map(self, if 'land_binary_mask' in self.env.priority_list and self.env.priority_list[ 'land_binary_mask'][0] == 'shape': logger.debug('Using custom shapes for plotting land..') - ax.add_geometries(self.readers['shape'].polys, + ax.add_geometries(self.env.readers['shape'].polys, ccrs.PlateCarree(globe=globe), facecolor=land_color, edgecolor='black') @@ -3699,8 +3702,8 @@ def get_map_background(self, ax, background, crs, time=None): variable = background[0] # A vector is requested else: variable = background # A scalar is requested - for readerName in self.readers: - reader = self.readers[readerName] + for readerName in self.env.readers: + reader = self.env.readers[readerName] if variable in reader.variables: if time is None or reader.start_time is None or ( time >= reader.start_time @@ -4398,7 +4401,7 @@ def __repr__(self): outStr += ' ' + variable + '\n' lazy_readers = [ - r for r in self.readers if self.readers[r].is_lazy is True + r for r in self.env.readers if self.env.readers[r].is_lazy is True ] if len(lazy_readers) > 0: outStr += '---\nLazy readers:\n' @@ -4537,7 +4540,7 @@ def calculate_ftle(self, if reader is None: logger.info('No reader provided, using first available:') - reader = list(self.readers.items())[0][1] + reader = list(self.env.readers.items())[0][1] logger.info(reader.name) if isinstance(reader, pyproj.Proj): proj = reader diff --git a/opendrift/models/basemodel/environment.py b/opendrift/models/basemodel/environment.py index ca252414e..0effa01f3 100644 --- a/opendrift/models/basemodel/environment.py +++ b/opendrift/models/basemodel/environment.py @@ -1,5 +1,5 @@ import logging -from typing import OrderedDict, Dict +from typing import OrderedDict, Dict, List import copy import traceback import numpy as np @@ -20,13 +20,12 @@ class Environment(Timeable, Configurable): readers: OrderedDict priority_list: OrderedDict required_variables: Dict - required_profiles_z_range = None # [min_depth, max_depth] - - max_speed = 1.0 + required_profiles_z_range: List[float] # [min_depth, max_depth] + max_speed: float proj_latlon = pyproj.Proj('+proj=latlong') - def __init__(self, required_variables, _config): + def __init__(self, required_variables, required_profiles_z_range, max_speed, _config): super().__init__() self.fallback_values = {} @@ -34,7 +33,9 @@ def __init__(self, required_variables, _config): self.priority_list = OrderedDict() self.required_variables = required_variables - self._config = _config + self.required_profiles_z_range = required_profiles_z_range + self.max_speed = max_speed + self._config = _config # reference to simulation config # Find variables which require profiles self.required_profiles = [ @@ -50,7 +51,7 @@ def __init__(self, required_variables, _config): and self.required_variables[var]['important'] is False ] - def finalize(self, simulation: Configurable, simulation_extent): + def finalize(self, simulation: 'OpenDriftSimulation', simulation_extent): """ Prepare environment for simulation. """ @@ -821,3 +822,12 @@ def get_environment(self, variables, time, lon, lat, z, profiles): self.timer_end('main loop:readers') return env.view(np.recarray), env_profiles, missing + +class HasEnvironment: + """ + A class that has an `Environment`. Some shortcuts for dealing with readers are provided to the inner `env` instance. + """ + env: Environment + + def add_reader(self, readers, variables=None, first=False): + self.env.add_reader(readers, variables, first) diff --git a/opendrift/models/chemicaldrift.py b/opendrift/models/chemicaldrift.py index 3e39e831d..2f0a9935a 100644 --- a/opendrift/models/chemicaldrift.py +++ b/opendrift/models/chemicaldrift.py @@ -362,13 +362,13 @@ def prepare_run(self): logger.info('Transfer rates:\n %s' % self.transfer_rates) self.SPM_vertical_levels_given = False - for key, value in self.readers.items(): + for key, value in self.env.readers.items(): if 'spm' in value.variables: if (hasattr(value,'sigma') or hasattr(value,'z') ): self.SPM_vertical_levels_given = True self.DOC_vertical_levels_given = False - for key, value in self.readers.items(): + for key, value in self.env.readers.items(): if 'doc' in value.variables: if (hasattr(value,'sigma') or hasattr(value,'z') ): self.DOC_vertical_levels_given = True @@ -1930,14 +1930,14 @@ def write_netcdf_chemical_density_map(self, filename, pixelsize_m='auto', zlevel from netCDF4 import Dataset, date2num #, stringtochar if landmask_shapefile is not None: - if 'shape' in self.readers.keys(): + if 'shape' in self.env.readers.keys(): # removing previously stored landmask - del self.readers['shape'] + del self.env.readers['shape'] # Adding new landmask from opendrift.readers import reader_shape custom_landmask = reader_shape.Reader.from_shpfiles(landmask_shapefile) self.add_reader(custom_landmask) - elif 'global_landmask' not in self.readers.keys(): + elif 'global_landmask' not in self.env.readers.keys(): from opendrift.readers import reader_global_landmask global_landmask = reader_global_landmask.Reader() self.add_reader(global_landmask) @@ -2024,9 +2024,9 @@ def write_netcdf_chemical_density_map(self, filename, pixelsize_m='auto', zlevel landmask = np.zeros_like(H[0,0,0,:,:]) if landmask_shapefile is not None: - landmask = self.readers['shape'].__on_land__(lon_array,lat_array) + landmask = self.env.readers['shape'].__on_land__(lon_array,lat_array) else: - landmask = self.readers['global_landmask'].__on_land__(lon_array,lat_array) + landmask = self.env.readers['global_landmask'].__on_land__(lon_array,lat_array) Landmask=np.zeros_like(H) for zi in range(len(z_array)-1): for sp in range(self.nspecies): diff --git a/opendrift/models/openberg.py b/opendrift/models/openberg.py index c94bd2c72..55e853415 100644 --- a/opendrift/models/openberg.py +++ b/opendrift/models/openberg.py @@ -104,6 +104,8 @@ class OpenBerg(OpenDriftSimulation): 'land_binary_mask': {'fallback': None}, } + required_profiles_z_range = [-120, 0] # [min_depth, max_depth] + # Default colors for plotting status_colors = {'initial': 'green', 'active': 'blue', 'missing_data': 'gray', 'stranded': 'red'} @@ -116,8 +118,6 @@ def __init__(self, d=None, label=None, *args, **kwargs): #self.required_profiles = ['x_sea_water_velocity', # 'y_sea_water_velocity'] # Get vertical current profiles - self.required_profiles_z_range = [-120, 0] # [min_depth, max_depth] - # Calling general constructor of parent class super(OpenBerg, self).__init__(*args, **kwargs) @@ -183,7 +183,7 @@ def prepare_run(self): """ # Retrieve profile provided in z dimension by reader variable_groups, reader_groups, missing_variables = \ - self.get_reader_groups(['x_sea_water_velocity','y_sea_water_velocity']) + self.env.get_reader_groups(['x_sea_water_velocity','y_sea_water_velocity']) if len(reader_groups) == 0: # No current data -> fallback values used @@ -192,10 +192,10 @@ def prepare_run(self): # Obtain depth levels from reader: reader_name = reader_groups[0][0] - if hasattr(self.readers[reader_name], 'z'): - profile = np.abs(np.ma.filled(self.readers[reader_name].z)) + if hasattr(self.env.readers[reader_name], 'z'): + profile = np.abs(np.ma.filled(self.env.readers[reader_name].z)) else: # ROMS sigma levels - profile = np.abs(np.ma.filled(self.readers[reader_name].zlevels)) + profile = np.abs(np.ma.filled(self.env.readers[reader_name].zlevels)) # If current data is missing in at least one dimension, no weighting is performed: if len(missing_variables) > 0: diff --git a/tests/models/test_chemicaldrift.py b/tests/models/test_chemicaldrift.py index a5c1fa04d..c5bbc225a 100644 --- a/tests/models/test_chemicaldrift.py +++ b/tests/models/test_chemicaldrift.py @@ -32,7 +32,7 @@ def test_chemicaldrift_partitioning_organics(): days=300 ntraj=200 - + # If one parameter is given as a list, different instances of # ChemicalDrift will be created configs = [ @@ -49,14 +49,14 @@ def test_chemicaldrift_partitioning_organics(): [ "chemical:particle_diameter_uncertainty" , 0 ], [ "drift:vertical_mixing" , False ], ] - + # possible to specify lists of different constant readers values # that will be used in different longitude intervals readers = [ [ "spm" , 80 ], [ "x_sea_water_velocity" , 0 ], [ "y_sea_water_velocity" , 0 ]] - + # Build list of ChemicalDrift objecs with different parameters o=list() for i in range(len(configs)): @@ -101,4 +101,4 @@ def test_chemicaldrift_partitioning_organics(): assert sum(o[0].elements.specie==2)/o[0].num_elements_total()*100 == 35 assert sum(o[1].elements.specie==2)/o[1].num_elements_total()*100 == 72.5 - assert sum(o[2].elements.specie==2)/o[2].num_elements_total()*100 == 84 \ No newline at end of file + assert sum(o[2].elements.specie==2)/o[2].num_elements_total()*100 == 84