Skip to content

sandbox

Damien Irving edited this page Nov 24, 2023 · 4 revisions
process.py

"""Functions for a process based analysis"""

import re
from collections import Counter
import calendar
import datetime


# Flatten the stacked model data into a table and discard rows with
# missing values
model_df = model_da_stacked.to_dataframe().dropna()

# Order the events by their 'pr' value, largest first, then preview the
# ten largest (the bare expression only displays in an interactive session)
ranked_events = model_df.sort_values(by='pr', ascending=False)
ranked_events.head(n=10)


def plot_seasonality(df):
    """Plot the event seasonality

    Draws (via `plt`) a bar chart of the number of events in each
    calendar month, with abbreviated month names on the x-axis.

    Parameters
    ----------
    df
        Table with an 'event_time' column (presumably a pandas DataFrame)
        whose values are date strings with the two-digit month at
        character positions 5-6, e.g. 'YYYY-MM-DD'.
    """

    # Only the month is needed here, so the year is not parsed
    event_months = [int(date[5:7]) for date in df['event_time'].values]
    month_counts = Counter(event_months)

    months = np.arange(1, 13)
    # A Counter returns 0 for missing keys, so months without any events
    # still get a (zero-height) bar
    counts = [month_counts[month] for month in months]

    plt.bar(months, counts)
    plt.ylabel('number of events')
    plt.xlabel('month')
    xlabels = [calendar.month_abbr[i] for i in months]
    plt.xticks(months, xlabels)


def get_run(file):
    """Get the model run information from a CMIP6 file path

    Parameters
    ----------
    file : str
        File path (or name) that contains a run label such as 'r1i1p1f1'.

    Returns
    -------
    str
        The first run label of the form r<N>i<N>p<N>f<N> found in `file`,
        where each <N> is one or more digits.

    Raises
    ------
    ValueError
        If `file` does not contain a run label.
    """

    # Every index must be numeric and may have more than one digit.
    # Matching "any character" instead would accept unrelated text
    # (e.g. 'rsimpsf_') and would truncate labels like 'r1i1p1f10'.
    match = re.search(r'r\d+i\d+p\d+f\d+', file)
    if match is None:
        raise ValueError(f'No run label (e.g. r1i1p1f1) found in: {file}')

    return match.group()


def get_dcpp_da(
    infile_list,
    row,
    init_year_offset=0,
    ensemble_reset=False,
    var='pr',
    start_date=None,
):
    """Get DCPP data for circulation plot

    Parameters
    ----------
    infile_list : list of str
        Candidate file paths. Each must contain a run label that
        `get_run` can extract.
    row
        Event record supporting item access, with 'init_date' (an object
        with a `strftime` method), 'ensemble' (convertible to int) and
        'event_time' (used as the end of the time slice).
    init_year_offset : int, default 0
        Added to the year of `row['init_date']` before building the file
        search string.
    ensemble_reset : bool, default False
        Currently unused.
        NOTE(review): the intended behaviour is not visible in this file.
    var : str, default 'pr'
        Name of the variable to extract from the opened dataset.
        NOTE(review): the default is an assumption based on the
        'mm d-1' unit conversion below - confirm.
    start_date : optional
        Start of the time slice. The default of None leaves the slice
        open-ended at the start.

    Returns
    -------
    The selected variable over the requested time slice, after
    conversion to 'mm d-1'.
    """

    init_year = int(row['init_date'].strftime('%Y')) + init_year_offset
    # NOTE(review): the +1 is kept from the original code, but the result
    # is used as a 0-based list index below - confirm this is not
    # selecting the run after the intended one.
    ensemble_index = int(row['ensemble']) + 1
    end_date = row['event_time']

    # Unique run labels, in order of first appearance in infile_list
    ensemble_labels = list(dict.fromkeys(map(get_run, infile_list)))

    # Files are matched on a substring like 's1960-r1i1p1f1'
    target = f's{init_year}-{ensemble_labels[ensemble_index]}'
    target_files = [infile for infile in infile_list if target in infile]

    ds = xr.open_mfdataset(target_files)
    da = ds.sel({'time': slice(start_date, end_date)})[var]
    da = xc.units.convert_units_to(da, 'mm d-1')

    return da
    

def _color_settings(color_var):
    """Return (levels, colorbar label, colormap, extend) for a shaded field"""

    if color_var == 'pr':
        return (
            [0, 100, 200, 300, 400, 500, 600, 700, 800],
            'total precipitation (mm)',
            cmocean.cm.rain,
            'max',
        )
    if color_var == 'ua300':
        return (
            [-30, -25, -20, -15, -10, -5, 0, 5, 10, 15, 20, 25, 30],
            '300hPa zonal wind',
            'RdBu_r',
            'both',
        )
    raise ValueError('Invalid color variable')


def _contour_levels(contour_var):
    """Return the contour levels for a contoured field"""

    if contour_var == 'z500':
        return np.arange(5000, 6300, 50)
    if contour_var == 'psl':
        return np.arange(900, 1100, 2.5)
    if contour_var == 'ua300':
        return np.arange(15, 60, 5)
    raise ValueError('Invalid contour variable')


def plot_circulation(
    start_date,
    end_date,
    top_n_events,
    color_file_list=None,
    color_var=None,
    contour_file_list=None,
    contour_var=None,
):
    """Plot the mean circulation for the n most extreme events

    Creates one map panel per row of the first `top_n_events` rows of the
    module-level `ranked_events` table, stacked vertically in one figure.

    Parameters
    ----------
    start_date
        Currently ignored: it is overwritten with `end_date` minus 14 days.
    end_date : str
        End of the plotted period, formatted 'YYYY-MM-DD'.
    top_n_events : int
        Number of events (panels) to plot.
    color_file_list, contour_file_list : optional
        Currently unused.
    color_var : {'pr', 'ua300'}, optional
        Variable to draw as filled colors. Nothing is shaded if omitted.
    contour_var : {'z500', 'psl', 'ua300'}, optional
        Variable to draw as labelled contour lines. No contours are drawn
        if omitted.

    Raises
    ------
    ValueError
        If `color_var` or `contour_var` is given but not one of the
        supported names. This is checked before any data is fetched.

    Notes
    -----
    NOTE(review): `get_da`, `model_name`, `init_date` and `ensemble` are
    not defined in this function or anywhere visible in this file; they
    must exist at module level for this function to run. `init_date` and
    `ensemble` look like they were meant to be read from each event row -
    confirm.
    """

    # The plot styling depends only on the variable names, so resolve (and
    # validate) it once rather than for every panel
    if color_var:
        color_levels, color_label, cmap, extend = _color_settings(color_var)
    if contour_var:
        contour_levels = _contour_levels(contour_var)

    fig = plt.figure(figsize=[10, 17])
    map_proj = ccrs.PlateCarree(central_longitude=180)

    top_events = ranked_events.head(n=top_n_events)
    # NOTE(review): `index` and `row` are not used yet, so every panel is
    # currently drawn from the same inputs.
    for plotnum, (index, row) in enumerate(top_events.iterrows(), start=1):
        ax = fig.add_subplot(top_n_events, 1, plotnum, projection=map_proj)

        # The plotted window is the 14 days leading up to end_date
        start_datetime = (
            datetime.datetime.strptime(end_date, "%Y-%m-%d")
            - datetime.timedelta(days=14)
        )
        start_date = start_datetime.strftime("%Y-%m-%d")
        # NOTE(review): `title` is built but never applied to the axes.
        title = f'{start_date} to {end_date} (initialisation: {init_date}, ensemble: {ensemble})'

        if color_var:
            # Only fetch data for a variable that was actually requested
            color_da = get_da(color_var, model_name, init_date, ensemble)
            color_da.plot(
                ax=ax,
                transform=ccrs.PlateCarree(),
                cmap=cmap,
                levels=color_levels,
                extend=extend,
                cbar_kwargs={'label': color_label},
            )

        if contour_var:
            contour_da = get_da(contour_var, model_name, init_date, ensemble)
            lines = contour_da.plot.contour(
                ax=ax,
                transform=ccrs.PlateCarree(),
                levels=contour_levels,
                colors=['0.1']
            )
            ax.clabel(lines, colors=['0.1'], manual=False, inline=True)

        ax.coastlines()
        ax.set_extent([90, 205, -55, 10], crs=ccrs.PlateCarree())
        ax.gridlines(linestyle='--', draw_labels=True)
        if contour_var:
            ax.set_title(f'Average {contour_var} ({contour_da.units}), {start_date} to {end_date}')
        else:
            ax.set_title(f'{start_date} to {end_date}')
Clone this wiki locally