Skip to content

Commit

Permalink
Add wigner plotting (#495)
Browse files Browse the repository at this point in the history
* Add wigner plotting to program

* Move plotting to plot module

* Remove numpy import

* Fixes after code-review

* Tidy up wigner plotting

* Run black

* Add test

* Update colours

* Update changelog

* rename tests

* Update strawberryfields/plot.py

Co-authored-by: antalszava <antalszava@gmail.com>

* add contours arg

* Update tests

* Update changelog

* Apply suggestions from code review

Co-authored-by: antalszava <antalszava@gmail.com>

* Remove url

* fix string

* fix arg name

Co-authored-by: antalszava <antalszava@gmail.com>
  • Loading branch information
thisac and antalszava authored Dec 14, 2020
1 parent e54205f commit bbae4bd
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 1 deletion.
20 changes: 20 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@
* `TDMProgram` objects can now be compiled and submitted via the API.
[(#476)](https://github.com/XanaduAI/strawberryfields/pull/476)

* Wigner functions can be plotted directly via Strawberry Fields using Plot.ly.
[(#495)](https://github.com/XanaduAI/strawberryfields/pull/495)

```python
prog = sf.Program(1)
eng = sf.Engine('fock', backend_options={"cutoff_dim": 10})

with prog.context as q:
gamma = 2
Vgate(gamma) | q[0]

state = eng.run(prog).state

xvec = np.arange(-4, 4, 0.01)
pvec = np.arange(-4, 4, 0.01)
mode = 0

sf.plot_wigner(state, mode, xvec, pvec, renderer="browser")
```

* Strawberry Fields code can be generated from a program (and an engine) by
calling `sf.io.generate_code(program, eng=engine)`.
[(#496)](https://github.com/XanaduAI/strawberryfields/pull/496)
Expand Down
1 change: 1 addition & 0 deletions strawberryfields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .parameters import par_funcs as math
from .program import Program
from .tdm import TDMProgram
from .plot import plot_wigner
from . import tdm

__all__ = [
Expand Down
122 changes: 122 additions & 0 deletions strawberryfields/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2018-2020 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
This module provides tools to visualize the state in various interactive
ways using Plot.ly.
"""
import numpy as np

plotly_error = (
"Plot.ly required for using this function. It can be installed as follows:"
"pip install plotly."
)


def _get_plotly():
"""Import Plot.ly on demand to avoid errors being raised unnecessarily."""
try:
# pylint:disable=import-outside-toplevel
import plotly.io as pio
except ImportError as e:
raise (plotly_error) from e
return pio


def plot_wigner(state, mode, xvec, pvec, renderer="browser", contours=True):
"""Plot the Wigner function with Plot.ly.
Args:
state (:class:`.BaseState`): the state used for plotting
mode (int): mode used to calculate the reduced Wigner function
xvec (array): array of discretized :math:`x` quadrature values
pvec (array): array of discretized :math:`p` quadrature values
renderer (string): the renderer for plotting with Plot.ly
contours (bool): whether to show the contour lines in the plot
"""
pio = _get_plotly()
pio.renderers.default = renderer

data = state.wigner(mode, xvec, pvec)
new_chart = generate_wigner_chart(data, xvec, pvec, contours=contours)
pio.show(new_chart)


def generate_wigner_chart(data, xvec, pvec, contours=True):
"""Populates a chart dictionary with reduced Wigner function surface plot data.
Args:
data (array): 2D array of size [len(xvec), len(pvec)], containing reduced
Wigner function values for specified x and p values.
xvec (array): array of discretized :math:`x` quadrature values
pvec (array): array of discretized :math:`p` quadrature values
contours (bool): whether to show the contour lines in the plot
Returns:
dict: a Plot.ly JSON-format surface plot
"""
textcolor = "#787878"

chart = {
"data": [
{
"type": "surface",
"colorscale": [],
"x": [],
"y": [],
"z": [],
"contours": {
"z": {},
},
}
],
"layout": {
"scene": {
"xaxis": {},
"yaxis": {},
"zaxis": {},
}
},
}

chart["data"][0]["type"] = "surface"
chart["data"][0]["colorscale"] = [
[0.0, "purple"],
[0.25, "red"],
[0.5, "yellow"],
[0.75, "green"],
[1.0, "blue"],
]

chart["data"][0]["x"] = xvec.tolist()
chart["data"][0]["y"] = pvec.tolist()
chart["data"][0]["z"] = data.tolist()

chart["data"][0]["contours"]["z"]["show"] = contours

chart["data"][0]["cmin"] = -1 / np.pi
chart["data"][0]["cmax"] = 1 / np.pi

chart["layout"]["paper_bgcolor"] = "white"
chart["layout"]["plot_bgcolor"] = "white"
chart["layout"]["font"] = {"color": textcolor}
chart["layout"]["scene"]["bgcolor"] = "white"

chart["layout"]["scene"]["xaxis"]["title"] = "x"
chart["layout"]["scene"]["xaxis"]["color"] = textcolor
chart["layout"]["scene"]["yaxis"]["title"] = "p"
chart["layout"]["scene"]["yaxis"]["color"] = textcolor
chart["layout"]["scene"]["yaxis"]["gridcolor"] = textcolor
chart["layout"]["scene"]["zaxis"]["title"] = "W(x,p)"

return chart
2 changes: 1 addition & 1 deletion strawberryfields/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
import copy
import numbers
import warnings
import networkx as nx

import blackbird as bb
from blackbird.utils import match_template
import networkx as nx

import strawberryfields as sf

Expand Down
56 changes: 56 additions & 0 deletions tests/frontend/test_sf_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2019-2020 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Unit tests for strawberryfields.plot
"""
import pytest

import numpy as np
import plotly.io as pio

import strawberryfields as sf
from strawberryfields.ops import Sgate, BSgate, MeasureFock
from strawberryfields.plot import plot_wigner

pytestmark = pytest.mark.frontend

@pytest.fixture(scope="module")
def prog():
"""Program used for testing"""
program = sf.Program(2)

with program.context as q:
Sgate(0.54, 0) | q[0]
Sgate(0.54, 0) | q[1]
BSgate(6.283, 0.6283) | (q[0], q[1])
MeasureFock() | q

return program

class TestWignerPlotting:
"""Test the Wigner plotting function"""

@pytest.mark.parametrize("renderer", ["png", "json", "browser"])
@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("contours", [True, False])
def test_no_errors(self, mode, renderer, contours, prog, monkeypatch):
"""Test that no errors are thrown when calling the `plot_wigner` function"""
eng = sf.Engine("gaussian")
results = eng.run(prog)

xvec = np.arange(-4, 4, 0.1)
pvec = np.arange(-4, 4, 0.1)
with monkeypatch.context() as m:
m.setattr(pio, "show", lambda x: None)
plot_wigner(results.state, mode, xvec, pvec, renderer=renderer, contours=contours)

0 comments on commit bbae4bd

Please sign in to comment.