From fce245e7d66945ce0ec3989c8932ee342f43bcfc Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 23 Aug 2019 13:37:19 +0100 Subject: [PATCH 1/8] Option for plotting gridlines in plot2d_wrapper --- xbout/plotting/plotfuncs.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 5503faaf..08f5649b 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -1,3 +1,5 @@ +import collections + import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -35,8 +37,8 @@ def regions(da, ax=None, **kwargs): def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, - add_limiter_hatching=True, cmap=None, vmin=None, vmax=None, - aspect=None, **kwargs): + add_limiter_hatching=True, gridlines=None, cmap=None, vmin=None, + vmax=None, aspect=None, **kwargs): """ Make a 2D plot using an xarray method, taking into account branch cuts (X-points). @@ -57,6 +59,9 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, Draw solid lines at the target surfaces add_limiter_hatching : bool, optional Draw hatched areas at the targets + gridlines : bool or int, optional + If True, draw grid lines on the plot. If an int is passed, it is used as the + stride when plotting grid lines (to reduce the number on the plot) cmap : Matplotlib colormap, optional Color map to use for the plot vmin : float, optional @@ -169,6 +174,24 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, extend = kwargs.get('extend', 'neither') fig.colorbar(artists[0], ax=ax, extend=extend) + if gridlines is not None: + if gridlines is True: + gridlines = (1, 1) + if not isinstance(gridlines, collections.Sequence): + gridlines = (gridlines, gridlines) + R_global = da['R'] + R_global.attrs['metadata'] = da.metadata + + Z_global = da['Z'] + Z_global.attrs['metadata'] = da.metadata + + R_regions = _decompose_regions(da['R']) + Z_regions = _decompose_regions(da['Z']) + + for R, Z in zip(R_regions, Z_regions): + plt.plot(R[::gridlines[0], :].T, Z[::gridlines[0], :].T, color='k', lw=0.1) + plt.plot(R[:, ::gridlines[1]], Z[:, ::gridlines[1]], color='k', lw=0.1) + ax.set_title(da.name) if _is_core_only(da): From 2fa32adbaa93e0f813439012466559ae6c722721 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 10 Dec 2019 12:25:54 +0000 Subject: [PATCH 2/8] Make specification of 'gridlines' argument to plot2d_wrapper nicer Also specify in docstring what happens if a dict is passed. --- xbout/plotting/plotfuncs.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 08f5649b..7e3fafd5 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -59,9 +59,12 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, Draw solid lines at the target surfaces add_limiter_hatching : bool, optional Draw hatched areas at the targets - gridlines : bool or int, optional + gridlines : bool, int or slice or dict of bool, int or slice, optional If True, draw grid lines on the plot. If an int is passed, it is used as the - stride when plotting grid lines (to reduce the number on the plot) + stride when plotting grid lines (to reduce the number on the plot). If a slice is + passed it is used to select the grid lines to plot. + If a dict is passed, the 'x' entry (bool, int or slice) is used for the radial + grid-lines and the 'y' entry for the poloidal grid lines. cmap : Matplotlib colormap, optional Color map to use for the plot vmin : float, optional @@ -175,10 +178,20 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, fig.colorbar(artists[0], ax=ax, extend=extend) if gridlines is not None: - if gridlines is True: - gridlines = (1, 1) - if not isinstance(gridlines, collections.Sequence): - gridlines = (gridlines, gridlines) + # convert gridlines to dict + if not isinstance(gridlines, dict): + gridlines = {'x': gridlines, 'y': gridlines} + + for key in gridlines: + value = gridlines[key] + if value is True: + gridlines[key] = slice(None) + elif isinstance(value, int): + gridlines[key] = slice(0, None, value) + elif not value is None: + if not isinstance(value, slice): + raise ValueError('Argument passed to gridlines must be bool, int or ' + 'slice. Got ' + str(value)) R_global = da['R'] R_global.attrs['metadata'] = da.metadata @@ -189,8 +202,11 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, Z_regions = _decompose_regions(da['Z']) for R, Z in zip(R_regions, Z_regions): - plt.plot(R[::gridlines[0], :].T, Z[::gridlines[0], :].T, color='k', lw=0.1) - plt.plot(R[:, ::gridlines[1]], Z[:, ::gridlines[1]], color='k', lw=0.1) + if 'x' in gridlines and gridlines['x'] is not None: + plt.plot(R[:, gridlines['y']], Z[:, gridlines['y']], color='k', lw=0.1) + if 'y' in gridlines and gridlines['y'] is not None: + plt.plot(R[gridlines['y'], :].T, Z[gridlines['y'], :].T, color='k', + lw=0.1) ax.set_title(da.name) From 646e11052d03ae8c9ae2a94704e15275260f8d2f Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 5 Dec 2019 21:15:15 +0000 Subject: [PATCH 3/8] Arguments to make logarithmic color-scale in 2d plots 'norm' option allows any matplotlib.colors.Normalize instance to be passed, for general control of the color-scale. For convenience, also adds a 'logscale' option which can be set to True to create a standard log-scale if vmin and vmax are both positive or both negative, and a symmetric log-scale with a linear threshold of min(abs(vmin),abs(vmax))*1.e-5 if vmin and vmax have opposite signs. If a float is passed for 'logscale' and vmin and vmax have opposite signs then the value of 'logscale' is used instead of the default 1.e-5 to set the linear threshold. --- xbout/plotting/plotfuncs.py | 32 +++++++++++++++++++++++++++----- xbout/plotting/utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 7e3fafd5..1d567688 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -6,7 +6,8 @@ import xarray as xr -from .utils import _decompose_regions, _is_core_only, plot_separatrices, plot_targets +from .utils import (_create_norm, _decompose_regions, _is_core_only, plot_separatrices, + plot_targets) def regions(da, ax=None, **kwargs): @@ -37,8 +38,9 @@ def regions(da, ax=None, **kwargs): def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, - add_limiter_hatching=True, gridlines=None, cmap=None, vmin=None, - vmax=None, aspect=None, **kwargs): + add_limiter_hatching=True, gridlines=None, cmap=None, + norm = None, logscale=None, vmin=None, vmax=None, + aspect=None, **kwargs): """ Make a 2D plot using an xarray method, taking into account branch cuts (X-points). @@ -67,6 +69,15 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, grid-lines and the 'y' entry for the poloidal grid lines. cmap : Matplotlib colormap, optional Color map to use for the plot + norm : matplotlib.colors.Normalize instance, optional + Normalization to use for the color scale. + Cannot be set at the same time as 'logscale' + logscale : bool or float, optional + If True, default to a logarithmic color scale instead of a linear one. + If a non-bool type is passed it is treated as a float used to set the linear + threshold of a symmetric logarithmic scale as + linthresh=min(abs(vmin),abs(vmax))*logscale, defaults to 1e-5 if True is passed. + Cannot be set at the same time as 'norm' vmin : float, optional Minimum value for the color scale vmax : float, optional @@ -119,12 +130,23 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, vmin = np.min(levels) vmax = np.max(levels) + levels = kwargs.get('levels', 7) + if isinstance(levels, np.int): + levels = np.linspace(vmin, vmax, levels, endpoint=True) + # put levels back into kwargs + kwargs['levels'] = levels + else: + levels = np.array(list(levels)) + kwargs['levels'] = levels + vmin = np.min(levels) + vmax = np.max(levels) + # Need to create a colorscale that covers the range of values in the whole array. # Using the add_colorbar argument would create a separate color scale for each # separate region, which would not make sense. if method is xr.plot.contourf: # create colorbar - norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + norm = _create_norm(logscale, norm, vmin, vmax) sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap) # make colorbar have only discrete levels # average the levels so that colors in the colorbar represent the intervals @@ -149,7 +171,7 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, kwargs['vmax'] = vmax # create colorbar - norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + norm = _create_norm(logscale, norm, vmin, vmax) sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap) sm.set_array([]) cmap = sm.get_cmap() diff --git a/xbout/plotting/utils.py b/xbout/plotting/utils.py index c76ced15..3c4ca217 100644 --- a/xbout/plotting/utils.py +++ b/xbout/plotting/utils.py @@ -1,9 +1,33 @@ import warnings +import matplotlib as mpl import numpy as np import xarray as xr +def _create_norm(logscale, norm, vmin, vmax): + if logscale: + if norm is not None: + raise ValueError("norm and logscale cannot both be passed at the same " + "time.") + if vmin*vmax > 0: + # vmin and vmax have the same sign, so can use standard log-scale + norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax) + else: + # vmin and vmax have opposite signs, so use symmetrical logarithmic scale + if not isinstance(logscale, bool): + linear_scale = logscale + else: + linear_scale = 1.e-5 + linear_threshold = min(abs(vmin), abs(vmax)) * linear_scale + norm = mpl.colors.SymLogNorm(linear_threshold, vmin=vmin, + vmax=vmax) + elif norm is None: + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + return norm + + def plot_separatrix(da, sep_pos, ax, radial_coord='x'): """ Plots the separatrix as a black dotted line. From 03a2a17ef2a10cb003e750607651de6f8d5b9be3 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 10 Dec 2019 13:29:09 +0000 Subject: [PATCH 4/8] PEP8 fixes --- xbout/plotting/plotfuncs.py | 6 +++--- xbout/plotting/utils.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 1d567688..901b388b 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -39,8 +39,8 @@ def regions(da, ax=None, **kwargs): def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, add_limiter_hatching=True, gridlines=None, cmap=None, - norm = None, logscale=None, vmin=None, vmax=None, - aspect=None, **kwargs): + norm=None, logscale=None, vmin=None, vmax=None, aspect=None, + **kwargs): """ Make a 2D plot using an xarray method, taking into account branch cuts (X-points). @@ -210,7 +210,7 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, gridlines[key] = slice(None) elif isinstance(value, int): gridlines[key] = slice(0, None, value) - elif not value is None: + elif value is not None: if not isinstance(value, slice): raise ValueError('Argument passed to gridlines must be bool, int or ' 'slice. Got ' + str(value)) diff --git a/xbout/plotting/utils.py b/xbout/plotting/utils.py index 3c4ca217..2f3816ab 100644 --- a/xbout/plotting/utils.py +++ b/xbout/plotting/utils.py @@ -20,8 +20,7 @@ def _create_norm(logscale, norm, vmin, vmax): else: linear_scale = 1.e-5 linear_threshold = min(abs(vmin), abs(vmax)) * linear_scale - norm = mpl.colors.SymLogNorm(linear_threshold, vmin=vmin, - vmax=vmax) + norm = mpl.colors.SymLogNorm(linear_threshold, vmin=vmin, vmax=vmax) elif norm is None: norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) From 2312e2936307ea8605d32fc5f1e8fa48c6d30693 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 10 Dec 2019 14:25:10 +0000 Subject: [PATCH 5/8] Tidy up plot2d_wrapper --- xbout/plotting/plotfuncs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 901b388b..b3766e03 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -204,8 +204,7 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, if not isinstance(gridlines, dict): gridlines = {'x': gridlines, 'y': gridlines} - for key in gridlines: - value = gridlines[key] + for key, value in gridlines: if value is True: gridlines[key] = slice(None) elif isinstance(value, int): @@ -213,7 +212,7 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, elif value is not None: if not isinstance(value, slice): raise ValueError('Argument passed to gridlines must be bool, int or ' - 'slice. Got ' + str(value)) + 'slice. Got a ' + type(value) + ', ' + str(value)) R_global = da['R'] R_global.attrs['metadata'] = da.metadata From a801a5f149aa1a1c9fc52bbb264ad9a8c00c72c1 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 10 Dec 2019 14:25:29 +0000 Subject: [PATCH 6/8] Make slicing and transposing more robust when plotting gridlines --- xbout/plotting/plotfuncs.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index b3766e03..5219f10e 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -204,7 +204,7 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, if not isinstance(gridlines, dict): gridlines = {'x': gridlines, 'y': gridlines} - for key, value in gridlines: + for key, value in gridlines.items(): if value is True: gridlines[key] = slice(None) elif isinstance(value, int): @@ -223,11 +223,25 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, Z_regions = _decompose_regions(da['Z']) for R, Z in zip(R_regions, Z_regions): + if (not da.metadata['bout_xdim'] in R.dims + and not da.metadata['bout_ydim'] in R.dims): + # Small regions around X-point do not have segments in x- or y-directions, + # so skip + continue if 'x' in gridlines and gridlines['x'] is not None: - plt.plot(R[:, gridlines['y']], Z[:, gridlines['y']], color='k', lw=0.1) + # transpose in case Dataset or DataArray has been transposed away from the usual + # form + dim_order = (da.metadata['bout_xdim'], da.metadata['bout_ydim']) + yarg = {da.metadata['bout_ydim']: gridlines['x']} + plt.plot(R.isel(**yarg).transpose(*dim_order), + Z.isel(**yarg).transpose(*dim_order), color='k', lw=0.1) if 'y' in gridlines and gridlines['y'] is not None: - plt.plot(R[gridlines['y'], :].T, Z[gridlines['y'], :].T, color='k', - lw=0.1) + xarg = {da.metadata['bout_xdim']: gridlines['y']} + # Need to plot transposed arrays to make gridlines that go in the + # y-direction + dim_order = (da.metadata['bout_ydim'], da.metadata['bout_xdim']) + plt.plot(R.isel(**xarg).transpose(*dim_order), + Z.isel(**yarg).transpose(*dim_order), color='k', lw=0.1) ax.set_title(da.name) From 0f4f7b5d355e30c4ac98c1c65746e329ac37e791 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 10 Dec 2019 14:41:19 +0000 Subject: [PATCH 7/8] Fix value of 'bout_xdim' when setting toroidal coordinates When setting torodial coordinates, a 'psi_poloidal' coordinate is added, but the x-dimension is not renamed (as psixy is a 2d array, although its values only vary in x). Therefore 'bout_xdim' should be 'x' (the name of the dimension), not 'psi_poloidal'. --- xbout/geometries.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xbout/geometries.py b/xbout/geometries.py index 5837d660..d525fe27 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -152,7 +152,8 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # Record which dimensions 't', 'x', and 'y' were renamed to. ds.metadata['bout_tdim'] = coordinates['t'] - ds.metadata['bout_xdim'] = coordinates['x'] + # x dimension not renamed, so this is still 'x' + ds.metadata['bout_xdim'] = 'x' ds.metadata['bout_ydim'] = coordinates['y'] # If full data (not just grid file) then toroidal dim will be present From d265127f48a50cab89d2e7ea8cf62b00e694f1bb Mon Sep 17 00:00:00 2001 From: johnomotani Date: Tue, 10 Dec 2019 19:39:14 +0000 Subject: [PATCH 8/8] Tidy up check for if gridlines has 'x' or 'y' --- xbout/plotting/plotfuncs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 5219f10e..6ef66b4e 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -228,14 +228,14 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True, # Small regions around X-point do not have segments in x- or y-directions, # so skip continue - if 'x' in gridlines and gridlines['x'] is not None: + if gridlines.get('x') is not None: # transpose in case Dataset or DataArray has been transposed away from the usual # form dim_order = (da.metadata['bout_xdim'], da.metadata['bout_ydim']) yarg = {da.metadata['bout_ydim']: gridlines['x']} plt.plot(R.isel(**yarg).transpose(*dim_order), Z.isel(**yarg).transpose(*dim_order), color='k', lw=0.1) - if 'y' in gridlines and gridlines['y'] is not None: + if gridlines.get('y') is not None: xarg = {da.metadata['bout_xdim']: gridlines['y']} # Need to plot transposed arrays to make gridlines that go in the # y-direction