Skip to content

Commit

Permalink
fix typechecking by using a typeddict
Browse files Browse the repository at this point in the history
  • Loading branch information
jenshnielsen committed May 18, 2021
1 parent ac12293 commit fa524a0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 61 deletions.
21 changes: 14 additions & 7 deletions qcodes/dataset/data_export.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import Any, Dict, List, Sequence, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
from typing_extensions import TypedDict

from qcodes.dataset.data_set import DataSet, load_by_id
from qcodes.dataset.descriptions.param_spec import ParamSpecBase
Expand All @@ -10,6 +11,14 @@
log = logging.getLogger(__name__)


class DSPlotData(TypedDict):
name: str
unit: str
label: str
data: np.ndarray
shape: Optional[Tuple[int, ...]]


def flatten_1D_data_for_plot(rawdata: Union[Sequence[Sequence[Any]],
np.ndarray]) -> np.ndarray:
"""
Expand All @@ -31,8 +40,7 @@ def flatten_1D_data_for_plot(rawdata: Union[Sequence[Sequence[Any]],


@deprecate(alternative="dataset.get_parameter_data")
def get_data_by_id(run_id: int) -> \
List[List[Dict[str, Union[str, np.ndarray]]]]:
def get_data_by_id(run_id: int) -> List[List[DSPlotData]]:
"""
Load data from database and reshapes into 1D arrays with minimal
name, unit and label metadata.
Expand Down Expand Up @@ -77,7 +85,7 @@ def get_data_by_id(run_id: int) -> \

def _get_data_from_ds(
ds: DataSet,
) -> List[List[Dict[str, Union[None, str, np.ndarray, Tuple[int, ...]]]]]:
) -> List[List[DSPlotData]]:
dependent_parameters: Tuple[ParamSpecBase, ...] = ds.dependent_parameters

all_data = ds.cache.data()
Expand All @@ -93,18 +101,17 @@ def _get_data_from_ds(
dependencies = ds.description.interdeps.dependencies[dependent]

for param_spec_base in dependencies + (dependent,):
my_data_dict: Dict[str, Union[str, np.ndarray]] = {
my_data_dict: DSPlotData = {
"name": param_spec_base.name,
"unit": param_spec_base.unit,
"label": param_spec_base.label,
"data": data_dict[param_spec_base.name],
"shape": None,
}
data_dicts_list.append(my_data_dict)

if ds.description.shapes is not None:
data_dicts_list[-1]["shape"] = ds.description.shapes.get(dependent.name)
else:
data_dicts_list[-1]["shape"] = None

output.append(data_dicts_list)

Expand Down
107 changes: 53 additions & 54 deletions qcodes/dataset/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from qcodes.utils.plotting import auto_color_scale_from_config, find_scale_and_prefix

from .data_export import (
DSPlotData,
_get_data_from_ds,
_strings_as_ints,
flatten_1D_data_for_plot,
Expand All @@ -33,10 +34,9 @@
AxesTuple = Tuple[matplotlib.axes.Axes, matplotlib.colorbar.Colorbar]
AxesTupleList = Tuple[List[matplotlib.axes.Axes],
List[Optional[matplotlib.colorbar.Colorbar]]]
Number = Union[float, int]
# NamedData is the structure _get_data_from_ds returns and that plot_by_id
# uses internally
NamedData = List[List[Dict[str, Union[str, np.ndarray]]]]
NamedData = List[List[DSPlotData]]

# list of kwargs for plotting function, so that kwargs can be passed to
# :func:`plot_dataset` and will be distributed to the respective plotting func.
Expand Down Expand Up @@ -90,19 +90,19 @@ def heatmaphandler(**kwargs: Any) -> Any:
yield plot_handler_mapping[plottype](**kwargs.copy())


def plot_dataset(dataset: DataSet,
axes: Optional[Union[matplotlib.axes.Axes,
Sequence[matplotlib.axes.Axes]]] = None,
colorbars: Optional[Union[matplotlib.colorbar.Colorbar,
Sequence[
matplotlib.colorbar.Colorbar]]] = None,
rescale_axes: bool = True,
auto_color_scale: Optional[bool] = None,
cutoff_percentile: Optional[Union[Tuple[Number, Number],
Number]] = None,
complex_plot_type: str = 'real_and_imag',
complex_plot_phase: str = 'radians',
**kwargs: Any) -> AxesTupleList:
def plot_dataset(
dataset: DataSet,
axes: Optional[Union[matplotlib.axes.Axes, Sequence[matplotlib.axes.Axes]]] = None,
colorbars: Optional[
Union[matplotlib.colorbar.Colorbar, Sequence[matplotlib.colorbar.Colorbar]]
] = None,
rescale_axes: bool = True,
auto_color_scale: Optional[bool] = None,
cutoff_percentile: Optional[Union[Tuple[float, float], float]] = None,
complex_plot_type: str = "real_and_imag",
complex_plot_phase: str = "radians",
**kwargs: Any,
) -> AxesTupleList:
"""
Construct all plots for a given dataset
Expand Down Expand Up @@ -218,8 +218,8 @@ def plot_dataset(dataset: DataSet,
if len(data) == 2: # 1D PLOTTING
log.debug(f'Doing a 1D plot with kwargs: {kwargs}')

xpoints = cast(np.ndarray, data[0]['data'])
ypoints = cast(np.ndarray, data[1]['data'])
xpoints = data[0]["data"]
ypoints = data[1]["data"]

plottype = get_1D_plottype(xpoints, ypoints)
log.debug(f'Determined plottype: {plottype}')
Expand Down Expand Up @@ -305,19 +305,19 @@ def plot_dataset(dataset: DataSet,
return axeslist, new_colorbars


def plot_by_id(run_id: int,
axes: Optional[Union[matplotlib.axes.Axes,
Sequence[matplotlib.axes.Axes]]] = None,
colorbars: Optional[Union[matplotlib.colorbar.Colorbar,
Sequence[
matplotlib.colorbar.Colorbar]]] = None,
rescale_axes: bool = True,
auto_color_scale: Optional[bool] = None,
cutoff_percentile: Optional[Union[Tuple[Number, Number],
Number]] = None,
complex_plot_type: str = 'real_and_imag',
complex_plot_phase: str = 'radians',
**kwargs: Any) -> AxesTupleList:
def plot_by_id(
run_id: int,
axes: Optional[Union[matplotlib.axes.Axes, Sequence[matplotlib.axes.Axes]]] = None,
colorbars: Optional[
Union[matplotlib.colorbar.Colorbar, Sequence[matplotlib.colorbar.Colorbar]]
] = None,
rescale_axes: bool = True,
auto_color_scale: Optional[bool] = None,
cutoff_percentile: Optional[Union[Tuple[float, float], float]] = None,
complex_plot_type: str = "real_and_imag",
complex_plot_phase: str = "radians",
**kwargs: Any,
) -> AxesTupleList:
"""
Construct all plots for a given `run_id`. Here `run_id` is an
alias for `captured_run_id` for historical reasons. See the docs
Expand Down Expand Up @@ -358,7 +358,7 @@ def _complex_to_real_preparser(alldata: NamedData,
'but can only accept "real_and_imag" or '
'"mag_and_phase".')

newdata = []
newdata: NamedData = []

# we build a new NamedData object from the given `alldata` input.
# Note that the length of `newdata` will be larger than that of `alldata`
Expand All @@ -371,11 +371,11 @@ def _complex_to_real_preparser(alldata: NamedData,
new_group = []
new_groups: NamedData = [[], []]
for index, parameter in enumerate(group):
data = cast(np.ndarray, parameter['data'])
if data.dtype.kind == 'c':
p1, p2 = _convert_complex_to_real(parameter,
conversion=conversion,
degrees=degrees)
data = parameter["data"]
if data.dtype.kind == "c":
p1, p2 = _convert_complex_to_real(
parameter, conversion=conversion, degrees=degrees
)
if index < len(group) - 1:
# if the above condition is met, we are dealing with
# complex setpoints
Expand Down Expand Up @@ -405,11 +405,8 @@ def _complex_to_real_preparser(alldata: NamedData,


def _convert_complex_to_real(
parameter: Dict[str, Union[str, np.ndarray]],
conversion: str,
degrees: bool
) -> Tuple[Dict[str, Union[str, np.ndarray]],
Dict[str, Union[str, np.ndarray]]]:
parameter: DSPlotData, conversion: str, degrees: bool
) -> Tuple[DSPlotData, DSPlotData]:
"""
Do the actual conversion and turn one parameter into two.
Should only be called from within _complex_to_real_preparser.
Expand Down Expand Up @@ -446,7 +443,7 @@ def _convert_complex_to_real(
return new_parameters # type: ignore[return-value]


def _get_label_of_data(data_dict: Dict[str, Any]) -> str:
def _get_label_of_data(data_dict: DSPlotData) -> str:
return data_dict['label'] if data_dict['label'] != '' \
else data_dict['name']

Expand All @@ -458,17 +455,17 @@ def _make_axis_label(label: str, unit: str) -> str:
return label


def _make_label_for_data_axis(data: List[Dict[str, Any]], axis_index: int
) -> str:
def _make_label_for_data_axis(data: Sequence[DSPlotData], axis_index: int) -> str:
label = _get_label_of_data(data[axis_index])
unit = data[axis_index]['unit']
return _make_axis_label(label, unit)


def _set_data_axes_labels(ax: matplotlib.axes.Axes,
data: List[Dict[str, Any]],
cax: Optional[matplotlib.colorbar.Colorbar] = None
) -> None:
def _set_data_axes_labels(
ax: matplotlib.axes.Axes,
data: Sequence[DSPlotData],
cax: Optional[matplotlib.colorbar.Colorbar] = None,
) -> None:
ax.set_xlabel(_make_label_for_data_axis(data, 0))
ax.set_ylabel(_make_label_for_data_axis(data, 1))

Expand Down Expand Up @@ -665,8 +662,9 @@ def _scale_formatter(tick_value: float, pos: int, factor: float) -> str:
return "{:g}".format(tick_value*factor)


def _make_rescaled_ticks_and_units(data_dict: Dict[str, Any]) \
-> Tuple[matplotlib.ticker.FuncFormatter, str]:
def _make_rescaled_ticks_and_units(
data_dict: DSPlotData,
) -> Tuple[matplotlib.ticker.FuncFormatter, str]:
"""
Create a ticks formatter and a new label for the data that is to be used
on the axes where the data is plotted.
Expand Down Expand Up @@ -710,10 +708,11 @@ def _make_rescaled_ticks_and_units(data_dict: Dict[str, Any]) \
return ticks_formatter, new_label


def _rescale_ticks_and_units(ax: matplotlib.axes.Axes,
data: List[Dict[str, Any]],
cax: matplotlib.colorbar.Colorbar = None
) -> None:
def _rescale_ticks_and_units(
ax: matplotlib.axes.Axes,
data: Sequence[DSPlotData],
cax: matplotlib.colorbar.Colorbar = None,
) -> None:
"""
Rescale ticks and units for the provided axes as described in
:func:`~_make_rescaled_ticks_and_units`
Expand Down

0 comments on commit fa524a0

Please sign in to comment.