diff --git a/xbout/geometries.py b/xbout/geometries.py index 5837d660..d525fe27 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -152,7 +152,8 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # Record which dimensions 't', 'x', and 'y' were renamed to. ds.metadata['bout_tdim'] = coordinates['t'] - ds.metadata['bout_xdim'] = coordinates['x'] + # x dimension not renamed, so this is still 'x' + ds.metadata['bout_xdim'] = 'x' ds.metadata['bout_ydim'] = coordinates['y'] # If full data (not just grid file) then toroidal dim will be present diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 5503faaf..6ef66b4e 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -1,10 +1,13 @@ +import collections + import matplotlib import matplotlib.pyplot as plt import numpy as np import xarray as xr -from .utils import _decompose_regions, _is_core_only, plot_separatrices, plot_targets +from .utils import (_create_norm, _decompose_regions, _is_core_only, plot_separatrices, + plot_targets) def regions(da, ax=None, **kwargs): @@ -35,8 +38,9 @@ def regions(da, ax=None, **kwargs): def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, - add_limiter_hatching=True, cmap=None, vmin=None, vmax=None, - aspect=None, **kwargs): + add_limiter_hatching=True, gridlines=None, cmap=None, + norm=None, logscale=None, vmin=None, vmax=None, aspect=None, + **kwargs): """ Make a 2D plot using an xarray method, taking into account branch cuts (X-points). @@ -57,8 +61,23 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, Draw solid lines at the target surfaces add_limiter_hatching : bool, optional Draw hatched areas at the targets + gridlines : bool, int or slice or dict of bool, int or slice, optional + If True, draw grid lines on the plot. If an int is passed, it is used as the + stride when plotting grid lines (to reduce the number on the plot). If a slice is + passed it is used to select the grid lines to plot. + If a dict is passed, the 'x' entry (bool, int or slice) is used for the radial + grid-lines and the 'y' entry for the poloidal grid lines. cmap : Matplotlib colormap, optional Color map to use for the plot + norm : matplotlib.colors.Normalize instance, optional + Normalization to use for the color scale. + Cannot be set at the same time as 'logscale' + logscale : bool or float, optional + If True, default to a logarithmic color scale instead of a linear one. + If a non-bool type is passed it is treated as a float used to set the linear + threshold of a symmetric logarithmic scale as + linthresh=min(abs(vmin),abs(vmax))*logscale, defaults to 1e-5 if True is passed. + Cannot be set at the same time as 'norm' vmin : float, optional Minimum value for the color scale vmax : float, optional @@ -111,12 +130,23 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, vmin = np.min(levels) vmax = np.max(levels) + levels = kwargs.get('levels', 7) + if isinstance(levels, np.int): + levels = np.linspace(vmin, vmax, levels, endpoint=True) + # put levels back into kwargs + kwargs['levels'] = levels + else: + levels = np.array(list(levels)) + kwargs['levels'] = levels + vmin = np.min(levels) + vmax = np.max(levels) + # Need to create a colorscale that covers the range of values in the whole array. # Using the add_colorbar argument would create a separate color scale for each # separate region, which would not make sense. if method is xr.plot.contourf: # create colorbar - norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + norm = _create_norm(logscale, norm, vmin, vmax) sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap) # make colorbar have only discrete levels # average the levels so that colors in the colorbar represent the intervals @@ -141,7 +171,7 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, kwargs['vmax'] = vmax # create colorbar - norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + norm = _create_norm(logscale, norm, vmin, vmax) sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap) sm.set_array([]) cmap = sm.get_cmap() @@ -169,6 +199,50 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, extend = kwargs.get('extend', 'neither') fig.colorbar(artists[0], ax=ax, extend=extend) + if gridlines is not None: + # convert gridlines to dict + if not isinstance(gridlines, dict): + gridlines = {'x': gridlines, 'y': gridlines} + + for key, value in gridlines.items(): + if value is True: + gridlines[key] = slice(None) + elif isinstance(value, int): + gridlines[key] = slice(0, None, value) + elif value is not None: + if not isinstance(value, slice): + raise ValueError('Argument passed to gridlines must be bool, int or ' + 'slice. Got a ' + type(value) + ', ' + str(value)) + R_global = da['R'] + R_global.attrs['metadata'] = da.metadata + + Z_global = da['Z'] + Z_global.attrs['metadata'] = da.metadata + + R_regions = _decompose_regions(da['R']) + Z_regions = _decompose_regions(da['Z']) + + for R, Z in zip(R_regions, Z_regions): + if (not da.metadata['bout_xdim'] in R.dims + and not da.metadata['bout_ydim'] in R.dims): + # Small regions around X-point do not have segments in x- or y-directions, + # so skip + continue + if gridlines.get('x') is not None: + # transpose in case Dataset or DataArray has been transposed away from the usual + # form + dim_order = (da.metadata['bout_xdim'], da.metadata['bout_ydim']) + yarg = {da.metadata['bout_ydim']: gridlines['x']} + plt.plot(R.isel(**yarg).transpose(*dim_order), + Z.isel(**yarg).transpose(*dim_order), color='k', lw=0.1) + if gridlines.get('y') is not None: + xarg = {da.metadata['bout_xdim']: gridlines['y']} + # Need to plot transposed arrays to make gridlines that go in the + # y-direction + dim_order = (da.metadata['bout_ydim'], da.metadata['bout_xdim']) + plt.plot(R.isel(**xarg).transpose(*dim_order), + Z.isel(**yarg).transpose(*dim_order), color='k', lw=0.1) + ax.set_title(da.name) if _is_core_only(da): diff --git a/xbout/plotting/utils.py b/xbout/plotting/utils.py index c76ced15..2f3816ab 100644 --- a/xbout/plotting/utils.py +++ b/xbout/plotting/utils.py @@ -1,9 +1,32 @@ import warnings +import matplotlib as mpl import numpy as np import xarray as xr +def _create_norm(logscale, norm, vmin, vmax): + if logscale: + if norm is not None: + raise ValueError("norm and logscale cannot both be passed at the same " + "time.") + if vmin*vmax > 0: + # vmin and vmax have the same sign, so can use standard log-scale + norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax) + else: + # vmin and vmax have opposite signs, so use symmetrical logarithmic scale + if not isinstance(logscale, bool): + linear_scale = logscale + else: + linear_scale = 1.e-5 + linear_threshold = min(abs(vmin), abs(vmax)) * linear_scale + norm = mpl.colors.SymLogNorm(linear_threshold, vmin=vmin, vmax=vmax) + elif norm is None: + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + return norm + + def plot_separatrix(da, sep_pos, ax, radial_coord='x'): """ Plots the separatrix as a black dotted line.