Skip to content

Commit

Permalink
pass var names as kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyhales committed Feb 4, 2024
1 parent 461c6fe commit 02ddafe
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
*.pyc
example.py
.DS_store
tests/test_results/*.nc
tests/test_results/*.nc
*.egg-info
build
dist
43 changes: 32 additions & 11 deletions basininflow/inflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def create_inflow_file(lsm_data: str,
timestep: datetime.timedelta = None,
cumulative: bool = False,
file_label: str = None,
force_positive_runoff: bool = False, ) -> None:
force_positive_runoff: bool = False,
runoff_var: str = None,
x_var: str = None,
y_var: str = None,
time_var: str = None, ) -> None:
"""
Generate inflow files for use with RAPID.
Expand Down Expand Up @@ -81,6 +85,14 @@ def create_inflow_file(lsm_data: str,
Label to include in the file name for organization purposes.
force_positive_runoff: bool, optional
Set all negative values to zero. Default is False
runoff_var: str, optional
Name of the runoff variable in the LSM data. If None, will search lsm dataset for recognized names and guess
x_var: str, optional
Name of the x variable in the LSM data. If None, will search lsm dataset for recognized names and guess
y_var: str, optional
Name of the y variable in the LSM data. If None, will search lsm dataset for recognized names and guess
time_var: str, optional
Name of the time variable in the LSM data. If None, will search lsm dataset for recognized names and guess
"""
# Ensure that every input file exists
if weight_table is not None and not os.path.exists(weight_table):
Expand Down Expand Up @@ -108,15 +120,24 @@ def create_inflow_file(lsm_data: str,
logging.info(f'Opening LSM files {lsm_data[0] if type(lsm_data) == list else lsm_data}')
with xr.open_mfdataset(lsm_data) as ds:
# Select the variable names
runoff_variable = [x for x in ['ro', 'RO', 'runoff', 'RUNOFF'] if x in ds.variables][0]
lon_variable = [x for x in ['lon', 'longitude', 'LONGITUDE', 'LON'] if x in ds.variables][0]
lat_variable = [x for x in ['lat', 'latitude', 'LATITUDE', 'LAT'] if x in ds.variables][0]
if not runoff_var:
logging.warning('Runoff variable provided. Guessing from default names')
runoff_var = [x for x in ['ro', 'RO', 'runoff', 'RUNOFF'] if x in ds.variables][0]
if not x_var:
logging.warning('X variable not given. Guessing from default names.')
x_var = [x for x in ['x', 'lon', 'longitude', 'LONGITUDE', 'LON'] if x in ds.variables][0]
if not y_var:
logging.warning('Y variable not given. Guessing from default names.')
y_var = [x for x in ['y', 'lat', 'latitude', 'LATITUDE', 'LAT'] if x in ds.variables][0]
if not time_var:
logging.warning('Time variable not given. Guessing from default names.')
time_var = [x for x in ['time', 'TIME', ] if x in ds.variables][0]

# Check that the input table dimensions match the dataset dimensions
# This gets us the shape, while ignoring the time dimension
variable_dims = ds[runoff_variable].dims
dataset_shape = [ds[runoff_variable].shape[variable_dims.index(lat_variable)],
ds[runoff_variable].shape[variable_dims.index(lon_variable)]]
variable_dims = ds[runoff_var].dims
dataset_shape = [ds[runoff_var].shape[variable_dims.index(y_var)],
ds[runoff_var].shape[variable_dims.index(x_var)]]

if weight_table is None:
# find a file in the input_dir which contains f"weight*{dataset_shape[0]}x{dataset_shape[1]}.csv"
Expand Down Expand Up @@ -151,11 +172,11 @@ def create_inflow_file(lsm_data: str,
lat_indices = weight_df['lat_index'].values # - min_lat_idx
lon_indices = weight_df['lon_index'].values # - min_lon_idx

ds = ds[runoff_variable]
ds = ds[runoff_var]

# Get approximate sizes of arrays and check if we have enough memory
out_array_size = ds['time'].shape[0] * sorted_rivid_array.shape[0]
in_array_size = ds['time'].shape[0] * n_wt_rows
out_array_size = ds[time_var].shape[0] * sorted_rivid_array.shape[0]
in_array_size = ds[time_var].shape[0] * n_wt_rows
if ds.ndim == 4:
in_array_size *= 2
total_size = out_array_size + in_array_size
Expand All @@ -175,7 +196,7 @@ def create_inflow_file(lsm_data: str,

# get the time array from the dataset
logging.info('Reading Time values')
datetime_array = ds['time'].to_numpy()
datetime_array = ds[time_var].to_numpy()

logging.info('Reading Runoff values')
if ds.ndim == 3:
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: inflows
channels:
- defaults
- conda-forge
- defaults
dependencies:
- dask
- netcdf4
Expand Down

0 comments on commit 02ddafe

Please sign in to comment.