Skip to content

Commit

Permalink
Merge pull request #3024 from jenshnielsen/handle_nan_filler_in_2d_plot
Browse files Browse the repository at this point in the history
Improved plotting of pre-shaped 2d data
  • Loading branch information
trevormorgan authored May 25, 2021
2 parents 8a65e59 + 0b4a53a commit 6dba3db
Show file tree
Hide file tree
Showing 5 changed files with 382 additions and 199 deletions.
37 changes: 26 additions & 11 deletions qcodes/dataset/data_export.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import Any, Dict, List, Sequence, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
from typing_extensions import TypedDict

from qcodes.dataset.data_set import DataSet, load_by_id
from qcodes.dataset.descriptions.param_spec import ParamSpecBase
Expand All @@ -10,11 +11,23 @@
log = logging.getLogger(__name__)


class DSPlotData(TypedDict):
"""
The dictionary used to represent data for use within `plot_dataset`
"""
name: str
unit: str
label: str
data: np.ndarray
shape: Optional[Tuple[int, ...]]


@deprecate(alternative="ndarray.flatten()")
def flatten_1D_data_for_plot(rawdata: Union[Sequence[Sequence[Any]],
np.ndarray]) -> np.ndarray:
"""
Cast the return value of the database query to
a numpy array
a 1D numpy array
Args:
rawdata: The return of the get_values function
Expand All @@ -23,16 +36,12 @@ def flatten_1D_data_for_plot(rawdata: Union[Sequence[Sequence[Any]],
A one-dimensional numpy array
"""
dataarray = np.array(rawdata)
shape = np.shape(dataarray)
dataarray = dataarray.reshape(np.product(shape))

dataarray = np.array(rawdata).flatten()
return dataarray


@deprecate(alternative="dataset.get_parameter_data")
def get_data_by_id(run_id: int) -> \
List[List[Dict[str, Union[str, np.ndarray]]]]:
def get_data_by_id(run_id: int) -> List[List[DSPlotData]]:
"""
Load data from database and reshapes into 1D arrays with minimal
name, unit and label metadata.
Expand Down Expand Up @@ -75,7 +84,9 @@ def get_data_by_id(run_id: int) -> \
return output


def _get_data_from_ds(ds: DataSet) -> List[List[Dict[str, Union[str, np.ndarray]]]]:
def _get_data_from_ds(
ds: DataSet,
) -> List[List[DSPlotData]]:
dependent_parameters: Tuple[ParamSpecBase, ...] = ds.dependent_parameters

all_data = ds.cache.data()
Expand All @@ -91,14 +102,18 @@ def _get_data_from_ds(ds: DataSet) -> List[List[Dict[str, Union[str, np.ndarray]
dependencies = ds.description.interdeps.dependencies[dependent]

for param_spec_base in dependencies + (dependent,):
my_data_dict: Dict[str, Union[str, np.ndarray]] = {
my_data_dict: DSPlotData = {
"name": param_spec_base.name,
"unit": param_spec_base.unit,
"label": param_spec_base.label,
"data": data_dict[param_spec_base.name].flatten(),
"data": data_dict[param_spec_base.name],
"shape": None,
}
data_dicts_list.append(my_data_dict)

if ds.description.shapes is not None:
data_dicts_list[-1]["shape"] = ds.description.shapes.get(dependent.name)

output.append(data_dicts_list)

return output
Expand Down
Loading

0 comments on commit 6dba3db

Please sign in to comment.