-
Notifications
You must be signed in to change notification settings - Fork 192
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add wigner plotting to program * Move plotting to plot module * Remove numpy import * Fixes after code-review * Tidy up wigner plotting * Run black * Add test * Update colours * Update changelog * rename tests * Update strawberryfields/plot.py Co-authored-by: antalszava <antalszava@gmail.com> * add contours arg * Update tests * Update changelog * Apply suggestions from code review Co-authored-by: antalszava <antalszava@gmail.com> * Remove url * fix string * fix arg name Co-authored-by: antalszava <antalszava@gmail.com>
- Loading branch information
1 parent
e54205f
commit bbae4bd
Showing
5 changed files
with
200 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright 2018-2020 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
r""" | ||
This module provides tools to visualize the state in various interactive | ||
ways using Plot.ly. | ||
""" | ||
import numpy as np | ||
|
||
plotly_error = ( | ||
"Plot.ly required for using this function. It can be installed as follows:" | ||
"pip install plotly." | ||
) | ||
|
||
|
||
def _get_plotly(): | ||
"""Import Plot.ly on demand to avoid errors being raised unnecessarily.""" | ||
try: | ||
# pylint:disable=import-outside-toplevel | ||
import plotly.io as pio | ||
except ImportError as e: | ||
raise (plotly_error) from e | ||
return pio | ||
|
||
|
||
def plot_wigner(state, mode, xvec, pvec, renderer="browser", contours=True): | ||
"""Plot the Wigner function with Plot.ly. | ||
Args: | ||
state (:class:`.BaseState`): the state used for plotting | ||
mode (int): mode used to calculate the reduced Wigner function | ||
xvec (array): array of discretized :math:`x` quadrature values | ||
pvec (array): array of discretized :math:`p` quadrature values | ||
renderer (string): the renderer for plotting with Plot.ly | ||
contours (bool): whether to show the contour lines in the plot | ||
""" | ||
pio = _get_plotly() | ||
pio.renderers.default = renderer | ||
|
||
data = state.wigner(mode, xvec, pvec) | ||
new_chart = generate_wigner_chart(data, xvec, pvec, contours=contours) | ||
pio.show(new_chart) | ||
|
||
|
||
def generate_wigner_chart(data, xvec, pvec, contours=True): | ||
"""Populates a chart dictionary with reduced Wigner function surface plot data. | ||
Args: | ||
data (array): 2D array of size [len(xvec), len(pvec)], containing reduced | ||
Wigner function values for specified x and p values. | ||
xvec (array): array of discretized :math:`x` quadrature values | ||
pvec (array): array of discretized :math:`p` quadrature values | ||
contours (bool): whether to show the contour lines in the plot | ||
Returns: | ||
dict: a Plot.ly JSON-format surface plot | ||
""" | ||
textcolor = "#787878" | ||
|
||
chart = { | ||
"data": [ | ||
{ | ||
"type": "surface", | ||
"colorscale": [], | ||
"x": [], | ||
"y": [], | ||
"z": [], | ||
"contours": { | ||
"z": {}, | ||
}, | ||
} | ||
], | ||
"layout": { | ||
"scene": { | ||
"xaxis": {}, | ||
"yaxis": {}, | ||
"zaxis": {}, | ||
} | ||
}, | ||
} | ||
|
||
chart["data"][0]["type"] = "surface" | ||
chart["data"][0]["colorscale"] = [ | ||
[0.0, "purple"], | ||
[0.25, "red"], | ||
[0.5, "yellow"], | ||
[0.75, "green"], | ||
[1.0, "blue"], | ||
] | ||
|
||
chart["data"][0]["x"] = xvec.tolist() | ||
chart["data"][0]["y"] = pvec.tolist() | ||
chart["data"][0]["z"] = data.tolist() | ||
|
||
chart["data"][0]["contours"]["z"]["show"] = contours | ||
|
||
chart["data"][0]["cmin"] = -1 / np.pi | ||
chart["data"][0]["cmax"] = 1 / np.pi | ||
|
||
chart["layout"]["paper_bgcolor"] = "white" | ||
chart["layout"]["plot_bgcolor"] = "white" | ||
chart["layout"]["font"] = {"color": textcolor} | ||
chart["layout"]["scene"]["bgcolor"] = "white" | ||
|
||
chart["layout"]["scene"]["xaxis"]["title"] = "x" | ||
chart["layout"]["scene"]["xaxis"]["color"] = textcolor | ||
chart["layout"]["scene"]["yaxis"]["title"] = "p" | ||
chart["layout"]["scene"]["yaxis"]["color"] = textcolor | ||
chart["layout"]["scene"]["yaxis"]["gridcolor"] = textcolor | ||
chart["layout"]["scene"]["zaxis"]["title"] = "W(x,p)" | ||
|
||
return chart |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright 2019-2020 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
r""" | ||
Unit tests for strawberryfields.plot | ||
""" | ||
import pytest | ||
|
||
import numpy as np | ||
import plotly.io as pio | ||
|
||
import strawberryfields as sf | ||
from strawberryfields.ops import Sgate, BSgate, MeasureFock | ||
from strawberryfields.plot import plot_wigner | ||
|
||
pytestmark = pytest.mark.frontend | ||
|
||
@pytest.fixture(scope="module") | ||
def prog(): | ||
"""Program used for testing""" | ||
program = sf.Program(2) | ||
|
||
with program.context as q: | ||
Sgate(0.54, 0) | q[0] | ||
Sgate(0.54, 0) | q[1] | ||
BSgate(6.283, 0.6283) | (q[0], q[1]) | ||
MeasureFock() | q | ||
|
||
return program | ||
|
||
class TestWignerPlotting: | ||
"""Test the Wigner plotting function""" | ||
|
||
@pytest.mark.parametrize("renderer", ["png", "json", "browser"]) | ||
@pytest.mark.parametrize("mode", [0, 1]) | ||
@pytest.mark.parametrize("contours", [True, False]) | ||
def test_no_errors(self, mode, renderer, contours, prog, monkeypatch): | ||
"""Test that no errors are thrown when calling the `plot_wigner` function""" | ||
eng = sf.Engine("gaussian") | ||
results = eng.run(prog) | ||
|
||
xvec = np.arange(-4, 4, 0.1) | ||
pvec = np.arange(-4, 4, 0.1) | ||
with monkeypatch.context() as m: | ||
m.setattr(pio, "show", lambda x: None) | ||
plot_wigner(results.state, mode, xvec, pvec, renderer=renderer, contours=contours) |