Skip to content

Commit

Permalink
Merge pull request #71 from boutproject/extra-poloidal-plot-options
Browse files Browse the repository at this point in the history
Extra options for poloidal plots
  • Loading branch information
TomNicholas authored Dec 10, 2019
2 parents 63bc17c + d265127 commit 00f527d
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 6 deletions.
3 changes: 2 additions & 1 deletion xbout/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):

# Record which dimensions 't', 'x', and 'y' were renamed to.
ds.metadata['bout_tdim'] = coordinates['t']
ds.metadata['bout_xdim'] = coordinates['x']
# x dimension not renamed, so this is still 'x'
ds.metadata['bout_xdim'] = 'x'
ds.metadata['bout_ydim'] = coordinates['y']

# If full data (not just grid file) then toroidal dim will be present
Expand Down
84 changes: 79 additions & 5 deletions xbout/plotting/plotfuncs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import collections

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import xarray as xr

from .utils import _decompose_regions, _is_core_only, plot_separatrices, plot_targets
from .utils import (_create_norm, _decompose_regions, _is_core_only, plot_separatrices,
plot_targets)


def regions(da, ax=None, **kwargs):
Expand Down Expand Up @@ -35,8 +38,9 @@ def regions(da, ax=None, **kwargs):


def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True,
add_limiter_hatching=True, cmap=None, vmin=None, vmax=None,
aspect=None, **kwargs):
add_limiter_hatching=True, gridlines=None, cmap=None,
norm=None, logscale=None, vmin=None, vmax=None, aspect=None,
**kwargs):
"""
Make a 2D plot using an xarray method, taking into account branch cuts (X-points).
Expand All @@ -57,8 +61,23 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True,
Draw solid lines at the target surfaces
add_limiter_hatching : bool, optional
Draw hatched areas at the targets
gridlines : bool, int or slice or dict of bool, int or slice, optional
If True, draw grid lines on the plot. If an int is passed, it is used as the
stride when plotting grid lines (to reduce the number on the plot). If a slice is
passed it is used to select the grid lines to plot.
If a dict is passed, the 'x' entry (bool, int or slice) is used for the radial
grid-lines and the 'y' entry for the poloidal grid lines.
cmap : Matplotlib colormap, optional
Color map to use for the plot
norm : matplotlib.colors.Normalize instance, optional
Normalization to use for the color scale.
Cannot be set at the same time as 'logscale'
logscale : bool or float, optional
If True, default to a logarithmic color scale instead of a linear one.
If a non-bool type is passed it is treated as a float used to set the linear
threshold of a symmetric logarithmic scale as
linthresh=min(abs(vmin),abs(vmax))*logscale, defaults to 1e-5 if True is passed.
Cannot be set at the same time as 'norm'
vmin : float, optional
Minimum value for the color scale
vmax : float, optional
Expand Down Expand Up @@ -111,12 +130,23 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True,
vmin = np.min(levels)
vmax = np.max(levels)

levels = kwargs.get('levels', 7)
if isinstance(levels, np.int):
levels = np.linspace(vmin, vmax, levels, endpoint=True)
# put levels back into kwargs
kwargs['levels'] = levels
else:
levels = np.array(list(levels))
kwargs['levels'] = levels
vmin = np.min(levels)
vmax = np.max(levels)

# Need to create a colorscale that covers the range of values in the whole array.
# Using the add_colorbar argument would create a separate color scale for each
# separate region, which would not make sense.
if method is xr.plot.contourf:
# create colorbar
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
norm = _create_norm(logscale, norm, vmin, vmax)
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# make colorbar have only discrete levels
# average the levels so that colors in the colorbar represent the intervals
Expand All @@ -141,7 +171,7 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True,
kwargs['vmax'] = vmax

# create colorbar
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
norm = _create_norm(logscale, norm, vmin, vmax)
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
cmap = sm.get_cmap()
Expand Down Expand Up @@ -169,6 +199,50 @@ def plot2d_wrapper(da, method, *, ax=None, separatrix=True, targets=True,
extend = kwargs.get('extend', 'neither')
fig.colorbar(artists[0], ax=ax, extend=extend)

if gridlines is not None:
# convert gridlines to dict
if not isinstance(gridlines, dict):
gridlines = {'x': gridlines, 'y': gridlines}

for key, value in gridlines.items():
if value is True:
gridlines[key] = slice(None)
elif isinstance(value, int):
gridlines[key] = slice(0, None, value)
elif value is not None:
if not isinstance(value, slice):
raise ValueError('Argument passed to gridlines must be bool, int or '
'slice. Got a ' + type(value) + ', ' + str(value))
R_global = da['R']
R_global.attrs['metadata'] = da.metadata

Z_global = da['Z']
Z_global.attrs['metadata'] = da.metadata

R_regions = _decompose_regions(da['R'])
Z_regions = _decompose_regions(da['Z'])

for R, Z in zip(R_regions, Z_regions):
if (not da.metadata['bout_xdim'] in R.dims
and not da.metadata['bout_ydim'] in R.dims):
# Small regions around X-point do not have segments in x- or y-directions,
# so skip
continue
if gridlines.get('x') is not None:
# transpose in case Dataset or DataArray has been transposed away from the usual
# form
dim_order = (da.metadata['bout_xdim'], da.metadata['bout_ydim'])
yarg = {da.metadata['bout_ydim']: gridlines['x']}
plt.plot(R.isel(**yarg).transpose(*dim_order),
Z.isel(**yarg).transpose(*dim_order), color='k', lw=0.1)
if gridlines.get('y') is not None:
xarg = {da.metadata['bout_xdim']: gridlines['y']}
# Need to plot transposed arrays to make gridlines that go in the
# y-direction
dim_order = (da.metadata['bout_ydim'], da.metadata['bout_xdim'])
plt.plot(R.isel(**xarg).transpose(*dim_order),
Z.isel(**yarg).transpose(*dim_order), color='k', lw=0.1)

ax.set_title(da.name)

if _is_core_only(da):
Expand Down
23 changes: 23 additions & 0 deletions xbout/plotting/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,32 @@
import warnings

import matplotlib as mpl
import numpy as np
import xarray as xr


def _create_norm(logscale, norm, vmin, vmax):
if logscale:
if norm is not None:
raise ValueError("norm and logscale cannot both be passed at the same "
"time.")
if vmin*vmax > 0:
# vmin and vmax have the same sign, so can use standard log-scale
norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
else:
# vmin and vmax have opposite signs, so use symmetrical logarithmic scale
if not isinstance(logscale, bool):
linear_scale = logscale
else:
linear_scale = 1.e-5
linear_threshold = min(abs(vmin), abs(vmax)) * linear_scale
norm = mpl.colors.SymLogNorm(linear_threshold, vmin=vmin, vmax=vmax)
elif norm is None:
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

return norm


def plot_separatrix(da, sep_pos, ax, radial_coord='x'):
"""
Plots the separatrix as a black dotted line.
Expand Down

0 comments on commit 00f527d

Please sign in to comment.