diff --git a/pyleoclim/core/multiplegeoseries.py b/pyleoclim/core/multiplegeoseries.py index 714fbb75..8b66e57c 100644 --- a/pyleoclim/core/multiplegeoseries.py +++ b/pyleoclim/core/multiplegeoseries.py @@ -85,7 +85,7 @@ def __init__(self, series_list, time_unit=None, label=None): def map(self, marker='archiveType', hue='archiveType', size=None, cmap=None, edgecolor='k', projection='auto', - proj_default=True, crit_dist=5000, colorbar=True, + proj_default=True, crit_dist=5000, colorbar=True,color_scale_type=None, background=True, borders=False, coastline=True,rivers=False, lakes=False, land=True,ocean=True, figsize=None, fig=None, scatter_kwargs=None, gridspec_kwargs=None, legend=True, gridspec_slot=None, lgd_kwargs=None, savefig_settings=None, **kwargs): @@ -180,6 +180,10 @@ def map(self, marker='archiveType', hue='archiveType', size=None, cmap=None, Whether to draw a colorbar on the figure if the data associated with hue are numeric. Default is True. + color_scale_type : str, optional + Setting to "discrete" will force a discrete color scale with a default bin number of max(11, n) where n=number of unique values$^{\frac{1}{2}}$ + Default is None + lgd_kwargs : dict, optional Dictionary of arguments for `matplotlib.pyplot.legend `_. @@ -406,6 +410,33 @@ def pca(self, weights=None,missing='fill-em',tol_em=5e-03, max_em_iter=100,**pca 'gridspec_kwargs': {'width_ratios': [.5, 1,14, 4], 'wspace':-.065}, 'lgd_kwargs':{'bbox_to_anchor':[-.015,1]}}) + + One can also configure how hue information gets displayed: + + .. jupyter-execute:: + + # Dashboard with hue as a legend category + res.modeplot(index=1, size='elevation', map_kwargs= dict(colorbar=False)) + + # Dashboard with discrete colorbar + res.modeplot(index=1, size='elevation', map_kwargs= dict(color_scale_type='discrete')) + + # Dashboard with custom scalar mappable + sm = pyleo.utils.mapping.make_scalar_mappable(cmap='vlag', hue_vect=res.eigvecs[:, 1], n=21,norm_kwargs={'vcenter': -.5}) + res.modeplot(index=1, size='elevation', map_kwargs= dict(scalar_mappable=sm)) + + + and customize the marker variable: + + .. jupyter-execute:: + + # Dashboard with marker set to archiveType + res.modeplot(index=1, marker='archiveType', size='elevation', map_kwargs= dict(colorbar=True)) + + # Dashboard with marker set to archiveType + res.modeplot(index=1, marker='observationType', size='elevation', map_kwargs= dict(colorbar=True)) + + ''' # apply PCA fom parent class pca_res = super().pca(weights=weights,missing=missing,tol_em=tol_em, diff --git a/pyleoclim/core/multivardecomp.py b/pyleoclim/core/multivardecomp.py index d0d5d6e7..8134405e 100644 --- a/pyleoclim/core/multivardecomp.py +++ b/pyleoclim/core/multivardecomp.py @@ -184,8 +184,9 @@ def screeplot(self, figsize=[6, 4], uq='N82', title=None, ax=None, savefig_setti def modeplot(self, index=0, figsize=[8, 8], fig=None, savefig_settings=None,gs=None, title=None, title_kwargs=None, spec_method='mtm', cmap=None, - hue='EOF', marker='archiveType', size=None, scatter_kwargs=None, + hue='EOF', marker=None, size=None, scatter_kwargs=None, flip = False, map_kwargs=None, gridspec_kwargs=None): + ''' Dashboard visualizing the properties of a given mode, including: 1. The temporal coefficient (PC or similar) 2. its spectrum @@ -248,8 +249,9 @@ def modeplot(self, index=0, figsize=[8, 8], fig=None, savefig_settings=None,gs=N - gridspec_kwargs: dict; Optional values for adjusting the arrangement of the colorbar, map and legend in the map subplot - legend: bool; Whether to draw a legend on the figure. Default is True - colorbar: bool; Whether to draw a colorbar on the figure if the data associated with hue are numeric. Default is True - The default is None. - + - color_scale_type : str; Setting to "discrete" will force a discrete color scale with a default bin number of max(11, n) where n=number of unique values$^{\frac{1}{2}}$. Default is None + - scalar_mappable: matplotlib.cm.ScalarMappable; can be used to pass a matplotlib scalar mappable. See pyleoclim.utils.plotting.make_scalar_mappable for documentation on using the Pyleoclim utility, or the `Matplotlib tutorial on customizing colorbars `_. + scatter_kwargs : dict, optional Optional arguments configuring how data are plotted on a map. See description of scatter_kwargs in pyleoclim.utils.mapping.scatter_map @@ -263,7 +265,8 @@ def modeplot(self, index=0, figsize=[8, 8], fig=None, savefig_settings=None,gs=N marker : string, optional (only applicable if using scatter map) Grouping variable that will produce points with different markers. Can have a numeric dtype but will always be treated as categorical. - The default is 'archiveType'. + The default is None, which will produce circle markers. Alternatively, pass the name of a categorical variable, e.g. 'archiveType'. If 'archiveType' is specified, will attempt to use pyleoclim archiveType markers mapping, defaulting to '?' where values are unavailable. + Returns ------- @@ -274,6 +277,7 @@ def modeplot(self, index=0, figsize=[8, 8], fig=None, savefig_settings=None,gs=N ax : dict dictionary of matplotlib ax + See also -------- @@ -284,7 +288,10 @@ def modeplot(self, index=0, figsize=[8, 8], fig=None, savefig_settings=None,gs=N pyleoclim.utils.tsutils.eff_sample_size : Effective sample size pyleoclim.utils.mapping.scatter_map : mapping - + + pyleoclim.utils.plotting.make_scalar_mappable : Custom scalar mappable + + ''' from ..core.multiplegeoseries import MultipleGeoSeries @@ -351,12 +358,24 @@ def modeplot(self, index=0, figsize=[8, 8], fig=None, savefig_settings=None,gs=N map_gridspec_kwargs = map_kwargs.pop('gridspec_kwargs', {}) lgd_kwargs = map_kwargs.pop('lgd_kwargs', {}) + if scatter_kwargs is None: + scatter_kwargs = {} + else: + scatter_kwargs = scatter_kwargs.copy() if 'edgecolor' in map_kwargs.keys(): scatter_kwargs.update({'edgecolor': map_kwargs['edgecolor']}) legend = map_kwargs.pop('legend', True) colorbar = map_kwargs.pop('colorbar', True) + color_scale_type = map_kwargs.pop('color_scale_type', None) + + if marker is None: + marker = scatter_kwargs.pop('marker', None) + if hue is None: + hue = scatter_kwargs.pop('hue', None) + if size is None: + size = scatter_kwargs.pop('size', None) if isinstance(self.orig, MultipleGeoSeries): # This makes a bare bones dataframe from a MultipleGeoSeries object @@ -373,9 +392,9 @@ def modeplot(self, index=0, figsize=[8, 8], fig=None, savefig_settings=None,gs=N rivers=rivers, lakes=lakes, ocean=ocean, land=land, extent=extent, figsize=None, scatter_kwargs=scatter_kwargs, lgd_kwargs=lgd_kwargs, - gridspec_kwargs=map_gridspec_kwargs, colorbar=colorbar, + gridspec_kwargs=map_gridspec_kwargs, colorbar=colorbar, color_scale_type=color_scale_type, legend=legend, cmap=cmap, - fig=fig, gs_slot=gs[-1, :]) #label rf'$EOF_{index + 1}$' + fig=fig, gs_slot=gs[-1, :], **map_kwargs) #label rf'$EOF_{index + 1}$' else: # it must be a plain old MultipleSeries. No map for you! Just a spaghetti plot with the standardizes series ax['map'] = fig.add_subplot(gs[1:, :]) diff --git a/pyleoclim/utils/mapping.py b/pyleoclim/utils/mapping.py index 93e32b79..571d1982 100644 --- a/pyleoclim/utils/mapping.py +++ b/pyleoclim/utils/mapping.py @@ -23,11 +23,10 @@ import matplotlib.gridspec as gridspec from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable -from .plotting import savefig +from .plotting import savefig, make_scalar_mappable, keep_center_colormap, consolidate_legends from .lipdutils import PLOT_DEFAULT, LipdToOntology, CaseInsensitiveDict - def pick_proj(lat, lon, crit_dist=5000): ''' Pick projection based on the degree of clustering of coordinates. @@ -50,11 +49,11 @@ def pick_proj(lat, lon, crit_dist=5000): 'Orthographic' or 'Robinson' ''' - lon = lon_360_to_180(lon) # convert longitudes to [-180, 180] + lon = lon_360_to_180(lon) # convert longitudes to [-180, 180] - lat_c, lon_c = centroid_coords(lat,lon) # find coordinates of centroid + lat_c, lon_c = centroid_coords(lat, lon) # find coordinates of centroid - d = compute_dist(lat_c, lon_c, lat, lon) # computes distances to centroid + d = compute_dist(lat_c, lon_c, lat, lon) # computes distances to centroid dmax = np.array(d).max() # find maximum distance if dmax > crit_dist: proj = 'Robinson' @@ -63,6 +62,7 @@ def pick_proj(lat, lon, crit_dist=5000): return proj + def set_proj(projection='Robinson', proj_default=True): """ Set the projection for Cartopy. @@ -283,7 +283,7 @@ def set_proj(projection='Robinson', proj_default=True): def map(lat, lon, criteria, marker=None, color=None, - projection='auto', proj_default=True, crit_dist = 5000, + projection='auto', proj_default=True, crit_dist=5000, background=True, borders=False, rivers=False, lakes=False, figsize=None, ax=None, scatter_kwargs=None, legend=True, legend_title=None, lgd_kwargs=None, savefig_settings=None): @@ -530,11 +530,10 @@ def make_df(geo_ms, hue=None, marker=None, size=None, cols=None, d=None): lats = [geos.lat for geos in geo_series_list] lons = [geos.lon for geos in geo_series_list] - traits = [hue, marker, size] + traits = [hue, marker, size] if type(cols) == list: traits += cols value_d = {'lat': lats, 'lon': lons} - for trait in traits: # trait_d.keys(): # trait = trait_d[trait_key] @@ -544,19 +543,18 @@ def make_df(geo_ms, hue=None, marker=None, size=None, cols=None, d=None): for geos in geo_series_list: if trait in geos.__dict__.keys(): try: - trait_vals.append(LipdToOntology(geos.__dict__[trait]).lower().replace(" ","")) + trait_vals.append(LipdToOntology(geos.__dict__[trait]).lower().replace(" ", "")) except: trait_vals.append(None) else: trait_vals.append(None) - #trait_vals = [LipdToOntology(geos.__dict__[trait]).lower().replace(" ","") if trait in geos.__dict__.keys() else None for geos in + # trait_vals = [LipdToOntology(geos.__dict__[trait]).lower().replace(" ","") if trait in geos.__dict__.keys() else None for geos in # geo_series_list] - else: + else: trait_vals = [geos.__dict__[trait] if trait in geos.__dict__.keys() else None for geos in geo_series_list] value_d[trait] = [trait_val if trait_val != 'None' else None for trait_val in trait_vals] - geos_df = pd.DataFrame(value_d) if type(d) == dict: for trait in d.keys(): @@ -571,7 +569,7 @@ def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgeco proj_default=True, projection='auto', crit_dist=5000, background=True, borders=False, coastline=True, rivers=False, lakes=False, ocean=True, land=True, figsize=None, scatter_kwargs=None, gridspec_kwargs=None, extent='global', - lgd_kwargs=None, legend=True, colorbar=True, cmap=None, + lgd_kwargs=None, legend=True, colorbar=True, cmap=None, color_scale_type=None, fig=None, gs_slot=None, **kwargs): ''' @@ -648,6 +646,10 @@ def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgeco Whether the draw a colorbar on the figure if the data associated with hue are numeric. Default is True. + color_scale_type : str, optional + Setting to "discrete" will force a discrete color scale with a default bin number of max(11, n) where n=number of unique values$^{\frac{1}{2}}$ + Default is None + lgd_kwargs : dict, optional Dictionary of arguments for `matplotlib.pyplot.legend `_. @@ -676,16 +678,13 @@ def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgeco gridspec_kwargs : dict, optional Function assumes the possibility of a colorbar, map, and legend. A list of floats associated with the keyword `width_ratios` will assume the first (index=0) is the relative width of the colorbar, the second to last (index = -2) is the relative width of the map, and the last (index = -1) is the relative width of the area for the legend. - For information about Gridspec configuration, refer to `Matplotlib documentation _. The default is None. + For information about Gridspec configuration, refer to `Matplotlib documentation `_. The default is None. kwargs: dict, optional - 'missing_val_hue', 'missing_val_marker', 'missing_val_label' can all be used to change the way missing values are represented ('k', '?', are default hue and marker values will be associated with the label: 'missing'). - 'hue_mapping' and 'marker_mapping' can be used to submit dictionaries mapping hue values to colors and marker values to markers. Does not replace passing a string value for hue or marker. + - 'scalar_mappable' can be used to pass a matplotlib scalar mappable. See pyleoclim.utils.plotting.make_scalar_mappable for documentation on using the Pyleoclim utility, or the `Matplotlib tutorial on customizing colorbars `_. - Raises - ------ - TypeError - DESCRIPTION. Returns ------- @@ -699,66 +698,6 @@ def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgeco ''' - def keep_center_colormap(cmap, vmin, vmax, center=0): - - vmin = vmin - center - vmax = vmax - center - - vdelta = max([.2 * abs(vmin), .2 * abs(vmax)]) - vmax = vmax + .2 * vdelta - vmin = vmin - .2 * vdelta - - dv = max(-vmin, vmax) * 2 - N = int(256 * dv / (vmax - vmin)) - cont_map = cm.get_cmap(cmap, N) - newcolors = cont_map(np.linspace(0, 1, N)) - beg = int((dv / 2 + vmin) * N / dv) - end = N - int((dv / 2 - vmax) * N / dv) - newmap = mpl.colors.ListedColormap(newcolors[beg:end]) - - return newmap - - def make_scalar_mappable(cmap=None, hue_vect=None, n=None, norm_kwargs=None): - - if type(hue_vect) in [np.ndarray, pd.Series, list]: - ax_cmap = None - ax_norm = None - - if type(norm_kwargs) != dict: - norm_kwargs = {} - if 'vcenter' not in norm_kwargs.keys(): - norm_kwargs['vcenter'] = 0 - if 'clip' not in norm_kwargs.keys(): - norm_kwargs['clip'] = False - - if np.any((norm_kwargs['vcenter'] < max(hue_vect)) | (norm_kwargs['vcenter'] > min(hue_vect))) == True: - if cmap is None: - cmap = 'vlag' - # ax_cmap = keep_center_colormap(cmap, min(hue_vect), max(hue_vect), center=0) - ax_norm = mpl.colors.CenteredNorm( - **norm_kwargs) # vcenter=0, clip=False)#TwoSlopeNorm(0, vmin=min(hue_vect), vmax=max(hue_vect)) # - else: - ax_norm = mpl.colors.Normalize(vmin=min(hue_vect), vmax=max(hue_vect), clip=False) - if cmap is None: - cmap = 'viridis' - - if ax_cmap is None: - if type(cmap) == list: - if n is None: - ax_cmap = mpl.colors.LinearSegmentedColormap.from_list("MyCmapName", cmap) - else: - ax_cmap = mpl.colors.LinearSegmentedColormap.from_list("MyCmapName", cmap, N=n) - elif type(cmap) == str: - if n is None: - ax_cmap = plt.get_cmap(cmap) - else: - ax_cmap = plt.get_cmap(cmap, n) - else: - print('what madness is this?') - ax_sm = cm.ScalarMappable(norm=ax_norm, cmap=ax_cmap) - - return ax_sm - # def make_df(geo_ms, hue=None, marker=None, size=None, cols=None): # try: # geo_series_list = geo_ms.series_list @@ -782,22 +721,26 @@ def make_scalar_mappable(cmap=None, hue_vect=None, n=None, norm_kwargs=None): # return geos_df def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_var=None, edgecolor='w', - ax=None, ax_d=None, proj=None, scatter_kwargs=None, legend=True, lgd_kwargs=None, colorbar=True, - fig=None, #gs_slot=None, + ax=None, ax_d=None, proj=None, scatter_kwargs=None, legend=True, lgd_kwargs=None, colorbar=None, + fig=None, color_scale_type=None, # gs_slot=None, cmap=None, **kwargs): scatter_kwargs = {} if type(scatter_kwargs) != dict else scatter_kwargs lgd_kwargs = {} if type(lgd_kwargs) != dict else lgd_kwargs + kwargs = {} if type(kwargs) != dict else kwargs norm_kwargs = kwargs.pop('norm_kwargs', {}) + ax_sm = kwargs.pop('scalar_mappable', None) - #plot_defaults = copy.copy(PLOT_DEFAULT) + palette = None + hue_norm = None + # if (color_scale_type is not None) and (colorbar is None): + # colorbar = True + + # plot_defaults = copy.copy(PLOT_DEFAULT) f = copy.copy(PLOT_DEFAULT) plot_defaults = CaseInsensitiveDict() for key, value in f.items(): - plot_defaults[key]=value - palette = None - hue_norm = None - ax_sm = None + plot_defaults[key] = value _df = df if len(_df) == 1: @@ -817,7 +760,7 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va fig = plt.figure(figsize=(20, 10)) ax = fig.add_subplot() - transform=ccrs.PlateCarree() + transform = ccrs.PlateCarree() # if type(ax) == cartopy.mpl.geoaxes.GeoAxes: # transform=ccrs.PlateCarree() # if proj is not None: @@ -836,8 +779,21 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va scatter_kwargs['edgecolor'] = edgecolor hue_var = hue_var if hue_var in _df.columns else None + hue_var_type_numeric = False + if hue_var is not None: + hue_var_type_numeric = all(isinstance(i, (int, float)) for i in _df[_df[hue_var] != missing_val][ + hue_var]) # pd.to_numeric(_df[hue_var], errors='coerce').notnull().all() + marker_var = marker_var if marker_var in _df.columns else None + marker_var_type_numeric = False + if marker_var is not None: + marker_var_type_numeric = pd.to_numeric(_df[marker_var], errors='coerce').notnull().all() + size_var = size_var if size_var in _df.columns else None + size_var_type_numeric = False + if size_var is not None: + size_var_type_numeric = all(isinstance(i, (int, float)) for i in _df[_df[size_var] != missing_val][ + size_var]) # pd.to_numeric(_df[size_var], errors='coerce').notnull().all() trait_vars = [trait_var for trait_var in [hue_var, marker_var, size_var] if ((trait_var != None) and (trait_var in _df.columns))] @@ -849,7 +805,7 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va if _df[trait_var] in [None, 'None']: _df[trait_var] = missing_val - if size_var == None: + if size_var is None: scatter_kwargs['s'] = scatter_kwargs['s'] if 's' in scatter_kwargs else missing_d['size'] else: sizes = [size_val for size_val in _df[size_var].values if size_val != missing_val] @@ -860,7 +816,8 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va # if size does vary, filter out missing values; at present no strategy to depict missing size values _df = _df[_df[size_var] != missing_val] scatter_kwargs['sizes'] = (20, 200) if 'sizes' not in scatter_kwargs else scatter_kwargs['sizes'] - scatter_kwargs['size_norm'] = scatter_kwargs['size_norm'] if 'size_norm' in scatter_kwargs else (_df[size_var].min(), _df[size_var].max()) + scatter_kwargs['size_norm'] = scatter_kwargs['size_norm'] if 'size_norm' in scatter_kwargs else ( + _df[size_var].min(), _df[size_var].max()) trait_vars = [trait_var for trait_var in [hue_var, marker_var, size_var] if trait_var != None] # mapping between marker styles and marker values @@ -913,30 +870,49 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va # use hue mapping if supplied if 'hue_mapping' in kwargs: palette = kwargs['hue_mapping'] + # there should be different control for discrete and continuous hue elif hue_var == 'archiveType': palette = {key: value[0] for key, value in plot_defaults.items()} - colorbar = False - elif type(hue_var) == str: + elif isinstance(hue_var,str): #hue_var) == str: hue_data = _df[_df[hue_var] != missing_val] - if cmap != None: - palette = cmap - if len(hue_data[hue_var]) > 0: - trait_val_types = [True if type(val) in (np.str_, str) else False for val in hue_data[hue_var]] - if True in trait_val_types: - colorbar = False - if len(hue_data[hue_var].unique()) < 20: - palette = 'tab20' + # If scalar mappable was passed, try to extract components. + if ax_sm is not None: + try: + palette = ax_sm.cmap + except: + ax_sm = None # if can't extract a palette, the scalar mappable is not helpful so set to None to trigger normal flow + try: + hue_norm = ax_sm.norm + except: + hue_norm = None + + if ax_sm is None: + if cmap != None: + palette = cmap + if len(hue_data[hue_var]) > 0: + if hue_var_type_numeric is not True: + # trait_val_types = [True if type(val) in (np.str_, str) else False for val in hue_data[hue_var]] + # if True in trait_val_types: + colorbar = False + if len(hue_data[hue_var].unique()) < 20: + palette = 'tab20' + else: + if palette is None: + palette = 'viridis' + ax_sm = make_scalar_mappable(cmap=palette, hue_vect=None, n=len(hue_data[hue_var].unique()), + norm_kwargs=norm_kwargs) + palette = ax_sm.cmap else: - if palette is None: - palette = 'viridis' - ax_sm = make_scalar_mappable(cmap=palette, hue_vect=None, n=len(hue_data[hue_var].unique()), - norm_kwargs = norm_kwargs) - palette = ax_sm.cmap - else: - if ((type(palette) in [str, list]) or (palette == None)): - ax_sm = make_scalar_mappable(cmap=palette, hue_vect=hue_data[hue_var], norm_kwargs = norm_kwargs) - palette = ax_sm.cmap - hue_norm = ax_sm.norm # .autoscale(hue_data[hue_var]) + if ((type(palette) in [str, list]) or (palette is None)): + if color_scale_type == 'discrete': + n = max(10, int(np.ceil(np.sqrt(len(hue_data[hue_var].unique()))))) + ax_sm = make_scalar_mappable(cmap=palette, hue_vect=hue_data[hue_var], n=n, + norm_kwargs=norm_kwargs) + else: + ax_sm = make_scalar_mappable(cmap=palette, hue_vect=hue_data[hue_var], + norm_kwargs=norm_kwargs) + palette = ax_sm.cmap + hue_norm = ax_sm.norm # .autoscale(hue_data[hue_var]) if ((type(hue_var) == str) and (type(palette) == dict)): residual_traits = [trait for trait in _df[hue_var].unique() if @@ -950,35 +926,42 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va if type(hue_var) == str: scatter_kwargs['zorder'] = 13 hue_data = _df[_df[hue_var] == missing_val] - if len(hue_data)> 0: + if len(hue_data) > 0: sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, - style=marker_var, transform=transform, #change to transform=scatter_kwargs['transform'] + style=marker_var, transform=transform, + # change to transform=scatter_kwargs['transform'] palette=[missing_d['hue'] for ik in range(len(hue_data))], ax=ax, **scatter_kwargs) missing_handles, missing_labels = ax.get_legend_handles_labels() else: - missing_handles, missing_labels=[],[] + missing_handles, missing_labels = [], [] scatter_kwargs['zorder'] = 14 hue_data = _df[_df[hue_var] != missing_val] if hue_norm is not None: scatter_kwargs['hue_norm'] = hue_norm - sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var,transform=transform, #change to transform=scatter_kwargs['transform'] + + sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform, + # change to transform=scatter_kwargs['transform'] style=marker_var, palette=palette, ax=ax, **scatter_kwargs) else: scatter_kwargs['zorder'] = 13 scatter_kwargs['c'] = missing_d['hue'] - sns.scatterplot(data=_df, x=x, y=y, hue=hue_var, size=size_var, transform=transform, #change to transform=scatter_kwargs['transform'] + sns.scatterplot(data=_df, x=x, y=y, hue=hue_var, size=size_var, transform=transform, + # change to transform=scatter_kwargs['transform'] style=marker_var, ax=ax, **scatter_kwargs) missing_handles, missing_labels = ax.get_legend_handles_labels() + # h, l= consolidate_legends([ax], hue=hue_var, style =marker_var, size=size_var, colorbar=colorbar) h, l = ax.get_legend_handles_labels() - if len(l) == 0: + if ((len(l) == 2) and (l[-1] == 'missing')) or (len(l) < 2): legend = False - if legend == True: + # if legend is True, prep legend content + d_leg = {} + if legend is True: han = copy.copy(h[0]) han.set_alpha(0) blank_handle = han @@ -997,8 +980,6 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va # legend has one section for each trait but ax.get_legend_handles_labels() yields straight lists # This code reorganizes legend information hierarchically starting with available and adding values # from missing as needed - d_leg = {} - for pair in [(available_handles, available_labels), (missing_handles, missing_labels)]: pair_h, pair_l = pair[0], pair[1] for ik, label in enumerate(pair_l): @@ -1019,89 +1000,98 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va d_leg[key]['labels'].append(_label) d_leg[key]['handles'].append(pair_h[ik]) - if ((colorbar==True) and (hue_var in d_leg.keys()) and (ax_sm is not None)): - if ax_cb != None: - - plt.colorbar(ax_sm, cax=ax_cb, orientation='vertical', label=hue_var, - fraction=.6, - shrink=.4) - ax_cb.yaxis.set_label_position('left') - ax_cb.yaxis.set_ticks_position('left') - else: - plt.colorbar(ax_sm, ax=ax, orientation='vertical', label=hue_var, - shrink=.6) # ,label=colorbar_units, - if len(d_leg.keys()) == 1: - ax.legend().remove() - legend=False - ax_leg.remove() - ax_d.pop('leg', None) - else: - d_leg.pop(hue_var, None) - - if legend == True: - # Finally rebuild legend in single list with formatted section headers - handles, labels = [], [] - headers = True - if ((len(d_leg) == 1) and ('label' in d_leg.keys())): - headers = False - headers = lgd_kwargs.pop('headers', headers) - - for key in d_leg: - han = copy.copy(h[0]) - han.set_alpha(0) - if headers == True: - handles.append(han) - # labels.append('$\\bf{}$'.format('{' + key + '}')) - labels.append(key) - - tmp_labels, tmp_handles = [], [] - tmp_labels_missing, tmp_handles_missing = [], [] - for ik, label in enumerate(d_leg[key]['labels']): - if label == 'missing': - tmp_labels_missing.append(label) - tmp_handles_missing.append(d_leg[key]['handles'][ik]) - else: - tmp_labels.append(label) - tmp_handles.append(d_leg[key]['handles'][ik]) + # if colorbar is True, hue will be removed from legend content, or removed entirely if there is no other + if ((colorbar is True) and (hue_var in d_leg.keys()) and (ax_sm is not None)): + d_leg.pop(hue_var, None) + # if len(d_leg.keys()) == 1: + # # ax.legend().remove() + # legend = False + # # ax_leg.remove() + # # ax_d.pop('leg', None) + # else: + # d_leg.pop(hue_var, None) + + if (legend is True) and (len(d_leg.keys()) > 0): + # Finally rebuild legend in single list with formatted section headers + handles, labels = [], [] + headers = True + if ((len(d_leg) == 1) and ('label' in d_leg.keys())): + headers = False + headers = lgd_kwargs.pop('headers', headers) + + for key in d_leg.keys(): + han = copy.copy(h[0]) + han.set_alpha(0) + if headers is True: + handles.append(han) + # labels.append('$\\bf{}$'.format('{' + key + '}')) + labels.append(key) + + tmp_labels, tmp_handles = [], [] + tmp_labels_missing, tmp_handles_missing = [], [] + for ik, label in enumerate(d_leg[key]['labels']): + if label == 'missing': + tmp_labels_missing.append(label) + tmp_handles_missing.append(d_leg[key]['handles'][ik]) + else: + tmp_labels.append(label) + tmp_handles.append(d_leg[key]['handles'][ik]) - tmp_labels += tmp_labels_missing - tmp_handles += tmp_handles_missing + tmp_labels += tmp_labels_missing + tmp_handles += tmp_handles_missing - handles += tmp_handles - labels += tmp_labels + handles += tmp_handles + labels += tmp_labels - handles.append(han) - labels.append('') + handles.append(han) + labels.append('') - if 'loc' not in lgd_kwargs: - lgd_kwargs['loc'] = 'upper left' - if 'bbox_to_anchor' not in lgd_kwargs: - lgd_kwargs['bbox_to_anchor'] = (-.1,1)#(1, 1) + if 'loc' not in lgd_kwargs: + lgd_kwargs['loc'] = 'upper left' + if 'bbox_to_anchor' not in lgd_kwargs: + lgd_kwargs['bbox_to_anchor'] = (-.1, 1) # (1, 1) - built_legend = ax_leg.legend(handles, labels, **lgd_kwargs) - if headers == True: - for _text in built_legend.get_texts(): - if _text._text in d_leg.keys(): - _text.set_weight('bold') + built_legend = ax_leg.legend(handles, labels, **lgd_kwargs) + if headers is True: + for _text in built_legend.get_texts(): + if _text._text in d_leg.keys(): + _text.set_weight('bold') + ax_leg.add_artist(built_legend) - ax_leg.add_artist(built_legend) - ax_leg.set_axis_off() - ax.legend().remove() - if colorbar == False: - ax_cb.remove() - ax_d.pop('cb', None) + ax_leg.set_axis_off() + ax.legend().remove() + # if colorbar == False: + # ax_cb.remove() + # ax_d.pop('cb', None) else: ax.legend().remove() - ax_d.pop('cb', None) + # ax_d.pop('cb', None) ax_d.pop('leg', None) - for _ax in [ax_cb, ax_leg]: - if type(_ax) != None: - _ax.remove() + ax_leg.remove() + # for _ax in [ax_cb, ax_leg]: + # if type(_ax) != None: + # _ax.remove() + + # if colorbar is true, make colorbar + if ((colorbar is True) and (ax_sm is not None)): + # make colorbar + if ax_cb is not None: + plt.colorbar(ax_sm, cax=ax_cb, orientation='vertical', label=hue_var, + fraction=.6, + shrink=.4) + ax_cb.yaxis.set_label_position('left') + ax_cb.yaxis.set_ticks_position('left') + else: + plt.colorbar(ax_sm, ax=ax, orientation='vertical', label=hue_var, + shrink=.6) # ,label=colorbar_units, + elif colorbar is False: + ax_cb.remove() + ax_d.pop('cb', None) # a little squirrely that this function might return different types - if len(ax_d) >0:#== dict: + if len(ax_d) > 0: # == dict: if 'map' in ax_d.keys(): ax_d['map'] = ax if 'cb' in ax_d.keys(): @@ -1110,14 +1100,15 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va ax_d['leg'] = ax_leg return fig, ax_d else: - return fig, {'map':ax} - ###### End of plot_scatter + return fig, {'map': ax} + + ###### End of plot_scatter # from ..core.multiplegeoseries import MultipleGeoSeries # from ..core.geoseries import GeoSeries # if geos is not - if type(geos) != pd.DataFrame:#in [MultipleGeoSeries, GeoSeries]: + if type(geos) != pd.DataFrame: # in [MultipleGeoSeries, GeoSeries]: df = make_df(geos, hue=hue, marker=marker, size=size) elif type(geos) == pd.DataFrame: df = geos @@ -1128,6 +1119,8 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va gridspec_kwargs = {} if type(gridspec_kwargs) != dict else gridspec_kwargs scatter_kwargs = {} if type(scatter_kwargs) != dict else scatter_kwargs + if 'marker_var' in scatter_kwargs: + marker_var = scatter_kwargs.pop('marker_var') # scatter_kwargs['transform'] = ccrs.PlateCarree() lgd_kwargs = {} if type(lgd_kwargs) != dict else lgd_kwargs @@ -1141,9 +1134,9 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va df['lon'].values, crit_dist=crit_dist) if figsize == None: if projection == 'Robinson': - figsize = (18,6) + figsize = (18, 6) if projection == 'Orthographic': - figsize = (16,6) + figsize = (16, 6) # set the projection proj = set_proj(projection=projection, proj_default=proj_default) @@ -1171,18 +1164,18 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va # use subgridspecs to encourage the slot for the map to have an aspect ratio closer to that of the projection if gs_slot == None: - _gs = gridspec.GridSpec(1, 1)#, **gridspec_kwargs) + _gs = gridspec.GridSpec(1, 1) # , **gridspec_kwargs) gs_slot = _gs[0] if projection == 'Robinson': gs_sub = gs_slot.subgridspec(1, 3) - gs_subslot = gs_sub[0,:] + gs_subslot = gs_sub[0, :] elif projection == 'Orthographic': gs_sub = gs_slot.subgridspec(3, 6) gs_subslot = gs_sub[:, 1:5] else: gs_subslot = gs_slot - num_subplots = 1+legend+colorbar + num_subplots = 1 + legend + colorbar if 'width_ratios' in gridspec_kwargs: if len(gridspec_kwargs['width_ratios']) < num_subplots: print('Please respecify gridspec width_ratios. Reverting to defaults.') @@ -1190,8 +1183,10 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va else: gridspec_kwargs['width_ratios'] = [.7, 16, 6] - gridspec_kwargs['width_ratios'] = gridspec_kwargs['width_ratios'] if 'width_ratios' in gridspec_kwargs else [.7,.05,16, 5] - gs = gs_subslot.subgridspec(1, len(gridspec_kwargs['width_ratios']), **gridspec_kwargs) + gridspec_kwargs['width_ratios'] = gridspec_kwargs['width_ratios'] if 'width_ratios' in gridspec_kwargs else [.7, + .05, + 16, 5] + gs = gs_subslot.subgridspec(1, len(gridspec_kwargs['width_ratios']), **gridspec_kwargs) ax_d['cb'] = fig.add_subplot(gs[0]) ax_d['map'] = fig.add_subplot(gs[-2], projection=proj) @@ -1204,27 +1199,31 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va ax_d['map'].stock_img() # Other extra information - feature_d = {'borders':borders, 'lakes':lakes, 'rivers':rivers, 'land':land, 'ocean':ocean, 'coastline':coastline} - feature_d = {key:(value if type(value)==dict else {}) for key, value in feature_d.items() if value !=False } - feature_spec_defaults = {'borders':dict(alpha=.5, linewidths=(.5,)), 'lakes':dict(alpha=0.25), - 'rivers': {}, 'land':dict(alpha=0.5), 'ocean':dict(alpha=0.25), 'coastline':dict(linewidths=(.5,))} + feature_d = {'borders': borders, 'lakes': lakes, 'rivers': rivers, 'land': land, 'ocean': ocean, + 'coastline': coastline} + feature_d = {key: (value if type(value) == dict else {}) for key, value in feature_d.items() if value != False} + feature_spec_defaults = {'borders': dict(alpha=.5, linewidths=(.5,)), 'lakes': dict(alpha=0.25), + 'rivers': {}, 'land': dict(alpha=0.5), 'ocean': dict(alpha=0.25), + 'coastline': dict(linewidths=(.5,))} for key in feature_d.keys(): feature_spec_defaults[key].update(feature_d[key]) - feature_types = {'borders':cfeature.BORDERS, 'coastline':cfeature.COASTLINE, 'lakes':cfeature.LAKES, - 'rivers':cfeature.RIVERS, 'land':cfeature.LAND, 'ocean':cfeature.OCEAN} + feature_types = {'borders': cfeature.BORDERS, 'coastline': cfeature.COASTLINE, 'lakes': cfeature.LAKES, + 'rivers': cfeature.RIVERS, 'land': cfeature.LAND, 'ocean': cfeature.OCEAN} for feature in feature_d.keys(): ax_d['map'].add_feature(feature_types[feature], **feature_spec_defaults[feature]) if extent == 'global': ax_d['map'].set_global() - elif isinstance(extent, list) and len(extent)==4: - ax_d['map'].set_extent(extent,crs=ccrs.PlateCarree()) - + elif isinstance(extent, list) and len(extent) == 4: + ax_d['map'].set_extent(extent, crs=ccrs.PlateCarree()) + x = 'lon' y = 'lat' - _, ax_d = plot_scatter(df=df, x=x, y=y, hue_var=hue, size_var=size, marker_var=marker, ax_d=ax_d, proj=None, edgecolor=edgecolor, - cmap=cmap, scatter_kwargs=scatter_kwargs, legend=legend, lgd_kwargs=lgd_kwargs, **kwargs) # , **kwargs) + _, ax_d = plot_scatter(df=df, x=x, y=y, hue_var=hue, size_var=size, marker_var=marker, ax_d=ax_d, proj=None, + edgecolor=edgecolor, colorbar=colorbar, color_scale_type=color_scale_type, + cmap=cmap, scatter_kwargs=scatter_kwargs, legend=legend, lgd_kwargs=lgd_kwargs, + **kwargs) # , **kwargs) return fig, ax_d @@ -1324,11 +1323,12 @@ def within_distance(distance, radius): def lon_360_to_180(x): return (x + 180) % 360 - 180 + def lon_180_to_360(x): return x % 360 -def centroid_coords(lat,lon, true_centroid=False): +def centroid_coords(lat, lon, true_centroid=False): ''' Computes the centroid of the geographic coordinates via Shapely. @@ -1359,9 +1359,10 @@ def centroid_coords(lat,lon, true_centroid=False): lon = lon_360_to_180(np.array(lon)) lat = np.array(lat) if len(lon) >= 4: - p = Polygon([(x, y) for (x,y) in zip(lon,lat)]) + p = Polygon([(x, y) for (x, y) in zip(lon, lat)]) if true_centroid: - clat = p.centroid.y; clon = p.centroid.x + clat = p.centroid.y; + clon = p.centroid.x else: clat = p.representative_point().y clon = p.representative_point().x diff --git a/pyleoclim/utils/plotting.py b/pyleoclim/utils/plotting.py index a91086a3..2fbf7cd5 100644 --- a/pyleoclim/utils/plotting.py +++ b/pyleoclim/utils/plotting.py @@ -9,11 +9,14 @@ import matplotlib.pyplot as plt import pathlib import matplotlib as mpl +from matplotlib import cm import numpy as np import pandas as pd +import collections.abc from ..utils import lipdutils + # import pandas as pd # from matplotlib.patches import Rectangle # from matplotlib.collections import PatchCollection @@ -605,12 +608,12 @@ def set_style(style='journal', font_scale=1.0, dpi=300): 'xtick.minor.width': 0, 'ytick.minor.width': 0, }) - + elif 'matplotlib' in style: - #mpl.rcParams.update(mpl.rcParamsDefault) + # mpl.rcParams.update(mpl.rcParamsDefault) style_dict.update({}) - - + + else: raise ValueError(f'Style [{style}] not availabel!') @@ -882,11 +885,11 @@ def get_label_width(ax, label, buffer=0., fontsize=10): """ Helper function to find width of text when rendered in ax object """ - + text = ax.text(0, 0, label, size=fontsize) width = text.get_window_extent(renderer=ax.figure.canvas.get_renderer()).width text.remove() # Remove the text used for measurement - + return width + buffer @@ -1113,3 +1116,313 @@ def label_intervals(fig, ax, labels, x_locs, orientation='north', overlapping_se ax.plot([x_locs[i], x_locs[i]], [0, slot_height], **linestyle_kwargs) return ax + + +def make_scalar_mappable(cmap=None, hue_vect=None, n=None, norm_kwargs=None): + """ + Create a ScalarMappable object for mapping scalar data to colors. + + This function configures and returns a ScalarMappable object based on the provided colormap (`cmap`), the scalar values (`hue_vect`), the number of discrete colors (`n`), and normalization parameters (`norm_kwargs`). It supports dynamic selection of normalization and colormap based on the input parameters and the range of scalar values. + + Parameters + ---------- + cmap : str, list, or None, optional + The colormap to use for mapping scalar data to colors. Can be a name of a matplotlib colormap (str), a list of color names, or None. If None, defaults to 'vlag' if conditions for centered normalization are met, otherwise 'viridis'. + hue_vect : np.ndarray, pd.Series, list, or None, optional + An array-like object containing the scalar values to be mapped to colors. These values are used to determine the range and center for normalization. + n : int or None, optional + Specifies the number of discrete colors in the colormap if `cmap` is provided as a list. If None, the number of colors is not explicitly set. + norm_kwargs : dict or None, optional + A dictionary containing keyword arguments for the normalization process, specifically supporting 'vcenter' and 'clip'. Defaults to {'vcenter': 0, 'clip': False} if not provided or if provided keys are missing. + + + Returns: + ------- + ax_sm : matplotlib.cm.ScalarMappable + The configured ScalarMappable object, which can be used to map scalar data to colors based on the specified colormap and normalization settings. + + + Examples + -------- + + .. jupyter-execute:: + + import pyleoclim as pyleo + import numpy as np + + scalar_values = np.random.randn(100) + sm = pyleo.utils.plotting.make_scalar_mappable(cmap='viridis', hue_vect=scalar_values) + # Now `sm` can be used with matplotlib plotting functions to map scalar values to colors. + + sm = pyleo.utils.plotting.make_scalar_mappable(cmap='viridis', hue_vect=scalar_values, n=10) + # This creates a ScalarMappable a discrete color scale. + + sm = pyleo.utils.plotting.make_scalar_mappable(cmap=['blue', 'white', 'red'], hue_vect=scalar_values, norm_kwargs={'vcenter': 0}) + # This creates a ScalarMappable with a custom linear segmented colormap and centered normalization. + + + """ + + # if type(hue_vect) in [np.ndarray, pd.Series, list]: + if isinstance(hue_vect, collections.abc.Iterable) and not isinstance(hue_vect, dict): + ax_cmap = None + ax_norm = None + + if type(norm_kwargs) != dict: + norm_kwargs = {} + if 'vcenter' not in norm_kwargs.keys(): + norm_kwargs['vcenter'] = 0 + if 'clip' not in norm_kwargs.keys(): + norm_kwargs['clip'] = False + + if all(isinstance(i, (int, float)) for i in hue_vect): + if np.any((norm_kwargs['vcenter'] < max(hue_vect)) | (norm_kwargs['vcenter'] > min(hue_vect))) == True: + if cmap is None: + cmap = 'vlag' + # ax_cmap = keep_center_colormap(cmap, min(hue_vect), max(hue_vect), center=0) + ax_norm = mpl.colors.CenteredNorm( + **norm_kwargs) # vcenter=0, clip=False)#TwoSlopeNorm(0, vmin=min(hue_vect), vmax=max(hue_vect)) # + else: + ax_norm = mpl.colors.Normalize(vmin=min(hue_vect), vmax=max(hue_vect), clip=False) + if cmap is None: + cmap = 'viridis' + + # if np.any((norm_kwargs['vcenter'] < max(hue_vect)) | (norm_kwargs['vcenter'] > min(hue_vect))) == True: + # if cmap is None: + # cmap = 'vlag' + # # ax_cmap = keep_center_colormap(cmap, min(hue_vect), max(hue_vect), center=0) + # ax_norm = mpl.colors.CenteredNorm( + # **norm_kwargs) # vcenter=0, clip=False)#TwoSlopeNorm(0, vmin=min(hue_vect), vmax=max(hue_vect)) # + # else: + # ax_norm = mpl.colors.Normalize(vmin=min(hue_vect), vmax=max(hue_vect), clip=False) + # if cmap is None: + # cmap = 'viridis' + + if ax_cmap is None: + if type(cmap) == list: + if n is None: + ax_cmap = mpl.colors.LinearSegmentedColormap.from_list("MyCmapName", cmap) + else: + ax_cmap = mpl.colors.LinearSegmentedColormap.from_list("MyCmapName", cmap, N=n) + elif type(cmap) == str: + if n is None: + ax_cmap = plt.get_cmap(cmap) + else: + ax_cmap = plt.get_cmap(cmap, n) + else: + print('what madness is this?') + ax_sm = cm.ScalarMappable(norm=ax_norm, cmap=ax_cmap) + + return ax_sm + + +import copy + + +def consolidate_legends(ax, split_btwn=True, hue='relation', style='exp_type', size=None, colorbar=False): + break_pts = [] + hs, ls = [], [] + for ip, _ax in enumerate(ax): + try: + legend = ax[ip].get_legend() + l2 = [_l._text for _l in legend.get_texts()] + h2 = legend.legendHandles + except: + ax[ip].legend() + h2, l2 = ax[ip].get_legend_handles_labels() + hs.append(h2) + ls.append(l2) + + break_pts2 = [] + for ib, _l in enumerate(l2): + if _l in [hue, style, size]: + break_pt2 = ib + break_pts2.append(ib) + break_pts2.append(len(l2)) + break_pts.append(break_pts2) + + labels = [] + handles = [] + + print(break_pts, ls) + for iq, bp_lst in enumerate(break_pts): # [1:]): + for ik, bp in enumerate(bp_lst): + if ik > 0: + if split_btwn is True: + labels.append('') + handles.append(copy.copy(hs[0][-1])) + handles[-1].set_alpha(0) + for im, _l in enumerate(ls[iq][:bp]): + if _l not in labels: + labels.append(_l) + handles.append(hs[iq][im]) + + if colorbar is True: + start = 0 + end = 0 + looking = True + for ib, _l in enumerate(labels): + if looking is True: + if _l == hue: + start = ib + elif _l in [size, style]: + end = ib + looking = False + + labels = labels[0:start] + labels[end:] # [start:] + handles = handles[0:start] + handles[end:] + + start = 0 + looking = True + for ib, _l in enumerate(labels): + if looking is True: + if _l == '': + start = ib + 1 + else: + looking = False + + labels = labels[start:] + handles = handles[start:] + + return handles, labels + + +# def consolidate_legends(ax, split_btwn=True, hue='relation', style='exp_type', size=None): +# break_pts = [] +# hs, ls = [], [] +# for ip, _ax in enumerate(ax): +# try: +# legend = ax[ip].get_legend() +# l2 = [_l._text for _l in legend.get_texts()] +# h2 = legend.legendHandles +# except: +# ax[ip].legend() +# h2, l2 = ax[ip].get_legend_handles_labels() +# hs.append(h2) +# ls.append(l2) +# +# break_pt2 = len(l2) +# for ib, _l in enumerate(l2): +# if _l in [hue, style, size]: +# break_pts.append(ib) +# +# labels = [] +# handles = [] +# +# if break_pts[0] ==0: +# break_pt_start = 1 +# else: +# break_pt_start = 0 +# +# handles += hs[0][:break_pts[break_pt_start]] +# labels += ls[0][:break_pts[break_pt_start]] +# +# print(ls) +# if len(break_pts)>1: +# for ik, bp in enumerate(break_pts[break_pt_start:]): +# if split_btwn is True: +# labels.append('') +# handles.append(copy.copy(handles[-1])) +# handles[-1].set_alpha(0) +# for im, _l in enumerate(ls[0][:bp]): +# if _l not in labels: +# labels.append(_l) +# handles.append(hs[0][im]) +# print(labels) + +# labels.append('') +# handles.append(copy.copy(handles[-1])) +# handles[-1].set_alpha(0) +# +# handles += hs[0][break_pts[break_pt_start]:] +# labels += ls[0][break_pts[break_pt_start]:] +# +# for ik, bp in enumerate(break_pts[1:]): +# ik += 1 +# if split_btwn is True: +# labels.append('') +# handles.append(copy.copy(handles[-1])) +# handles[-1].set_alpha(0) +# for im, _l in enumerate(ls[ik][bp:]): +# if _l not in labels: +# labels.append(_l) +# handles.append(hs[ik][im]) +# +# return handles, labels + + +def keep_center_colormap(cmap, vmin, vmax, center=0): + """ + + + Adjust a colormap so that a specific value remains centered, and extend its limits symmetrically. + + This function modifies a given colormap such that the color representing the 'center' value + remains at the center of the colormap. It does this by adjusting the minimum and maximum values + symmetrically around the center and ensuring that the colormap covers a range that is at least + 20% larger than the absolute range from the center to either the original minimum or maximum value. + This is particularly useful for visualizing data with a significant central value (e.g., zero in + anomaly maps) to ensure that the colormap visually represents deviations from this center in a balanced manner. + + + Parameters + ---------- + cmap : str or Colormap + The name of the colormap or a Colormap instance to be adjusted. + vmin : float + The original minimum value in the data range that the colormap should cover. + vmax : float + The original maximum value in the data range that the colormap should cover. + center : float, optional + The value that should be centered in the adjusted colormap. Default is 0. + + + Returns + ------- + newmap : matplotlib.colors.ListedColormap + A new colormap instance adjusted so that 'center' is in the middle of the colormap, + with its range symmetrically extended to ensure balanced representation of values around the center. + + + Notes + ----- + The adjustment involves shifting the original `vmin` and `vmax` values to be symmetric around the `center`, + then expanding the range by at least 20% to ensure that the colormap's central part accurately represents + the centered value across the data. This adjusted colormap can then be used for data visualization tasks + where maintaining a perceptual 'zero' or central reference point is important. + + + Examples + -------- + + .. jupyter-execute:: + + import matplotlib.pyplot as plt + + new_colormap = keep_center_colormap('RdBu', vmin=-300, vmax=300, center=0) + plt.imshow(data, cmap=new_colormap) + plt.colorbar() + + This will create and use a colormap where the value 0 is centered, and the colormap is adjusted + to symmetrically represent deviations from this center, making it suitable for displaying anomalies + or differences from a baseline in a visually balanced manner. + + + """ + + vmin = vmin - center + vmax = vmax - center + + vdelta = max([.2 * abs(vmin), .2 * abs(vmax)]) + vmax = vmax + .2 * vdelta + vmin = vmin - .2 * vdelta + + dv = max(-vmin, vmax) * 2 + N = int(256 * dv / (vmax - vmin)) + cont_map = cm.get_cmap(cmap, N) + newcolors = cont_map(np.linspace(0, 1, N)) + beg = int((dv / 2 + vmin) * N / dv) + end = N - int((dv / 2 - vmax) * N / dv) + newmap = mpl.colors.ListedColormap(newcolors[beg:end]) + + return newmap