Skip to content

Commit

Permalink
It's working?!
Browse files Browse the repository at this point in the history
refactor

Cleanup code

Remove stuff
  • Loading branch information
thorbjoernl committed Jun 5, 2024
1 parent 9e5e822 commit 157357a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 1 deletion.
63 changes: 63 additions & 0 deletions pyaerocom/aux_var_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import cf_units
import numpy as np
import pandas as pd
import xarray as xr

from pyaerocom import const
from pyaerocom.variable_helpers import get_variable
Expand Down Expand Up @@ -835,3 +837,64 @@ def make_proxy_wetdep_from_O3(data):

data.data_flagged[new_var_name] = flags
return new_var_data


class CalcRollingAverage:
"""
This class implements a callable interface for calculating rolling averages
for variables.
"""

def __init__(self, window, *, min_periods: int = 1, invarname: str, outvarname: str):
"""
Parameters:
-----------
window : pd.Timedelta | int | float
A window for calculating the rolling average. If numeric type it is
assumed to be an interval in hours.
min_periods : int
Minimum observation count required for a valid rolling avg to be
calculated.
"""
if min_periods <= 0:
raise ValueError(f"minobs must be >=1. Provided value is {min_periods}")
self._min_periods = min_periods
self._window = self.sanitize_window(window)
self._invarname = invarname
self._outvarname = outvarname

def sanitize_window(self, window) -> pd.Timedelta:
"""Sanitation logic for the window parameter."""
if isinstance(window, pd.Timedelta):
return window

if isinstance(window, int | float):
return pd.Timedelta(window, "hours")

raise ValueError(f"Unexpected value of time window {window}.")

def __call__(self, ds: xr.Dataset) -> xr.DataArray:
"""Calculates a rolling average based on the configuration provided
in the constructor.
Parameters:
-----------
ds : xr.Dataset
Dataset from which to get required variables.
Returns : xr.DataArray
The calculated rolling avg.
"""
if not self._invarname in ds:
raise ValueError("Variable {self.invarname} is missing in dataset.")

df = ds[[self._invarname, "time"]].to_pandas()

rolling = df.rolling(window=self._window, min_periods=self._min_periods)

ravg = rolling.mean()

df[self._outvarname] = ravg

return df.to_xarray()
1 change: 1 addition & 0 deletions pyaerocom/io/mep/aux_vars.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pandas as pd
import xarray as xr
from geonum import atmosphere as atm

Expand Down
3 changes: 2 additions & 1 deletion pyaerocom/io/mscw_ctm/additional_variables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pandas as pd
import xarray as xr
from geonum.atmosphere import T0_STD, p0

from pyaerocom.aux_var_helpers import concx_to_vmrx
from pyaerocom.aux_var_helpers import CalcDailyMax, CalcRollingAverage, concx_to_vmrx
from pyaerocom.molmasses import get_molmass


Expand Down
6 changes: 6 additions & 0 deletions pyaerocom/io/mscw_ctm/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import xarray as xr

from pyaerocom import const
from pyaerocom.aux_var_helpers import CalcDailyMax, CalcRollingAverage
from pyaerocom.exceptions import VarNotAvailableError
from pyaerocom.griddeddata import GriddedData
from pyaerocom.units_helpers import UALIASES
Expand Down Expand Up @@ -106,6 +107,8 @@ class ReadMscwCtm:
"vmro3": ["conco3"],
# For Pollen
# "concpolyol": ["concspores"],
"conco3mda8": ["conco3"],
"conco3mda8max": ["conco3mda8"],
}

# Functions that are used to compute additional variables (i.e. one
Expand Down Expand Up @@ -149,6 +152,9 @@ class ReadMscwCtm:
"concSso2": calc_concSso2,
"vmro3": calc_vmro3,
# "concpolyol": calc_concpolyol,
"conco3mda8": CalcRollingAverage(
window=8, min_periods=6, invarname="conco3", outvarname="conco3mda8"
),
}

#: supported filename masks, placeholder is for frequencies
Expand Down
14 changes: 14 additions & 0 deletions tests/io/mscw_ctm/test_additional_variables.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import numpy as np
import xarray as xr

from pyaerocom.io.mscw_ctm.additional_variables import (
calc_concNhno3,
calc_concNnh3,
Expand Down Expand Up @@ -126,3 +129,14 @@ def test_update_EC_units():

assert (concCecpm25 == concCecpm25_from_func).all()
assert concCecpm25.units == concCecpm25_from_func.units


def test_calc_conco3mda8max():
time = xr.DataArray(xr.date_range("2024-01-01", "2024-12-31", freq="1h"), dims=("time"))
conco3 = xr.DataArray(np.linspace(start=0, stop=100, num=len(time)), dims=("time"))

fdata = xr.Dataset({"time": time, "conco3": conco3})

result = calc_conco3mda8max(fdata)

assert len(result["conco3mda8max"]) == 366

0 comments on commit 157357a

Please sign in to comment.