Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Commit

Permalink
Merge pull request #812 from Zoufalc/master
Browse files Browse the repository at this point in the history
qGAN comments
  • Loading branch information
manoelmarques authored Feb 6, 2020
2 parents d0bdae9 + 6349b16 commit 62311df
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 16 deletions.
17 changes: 12 additions & 5 deletions qiskit/aqua/algorithms/adaptive/qgan/qgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""
Quantum Generative Adversarial Network.
`Quantum Generative Adversarial Networks for learning and loading random distributions
<https://www.nature.com/articles/s41534-019-0223-2>`_
"""

from typing import Optional
Expand Down Expand Up @@ -41,6 +43,11 @@
class QGAN(QuantumAlgorithm):
"""
Quantum Generative Adversarial Network.
This adaptive algorithm uses the interplay of a generative
:class:`~qiskit.aqua.components.neural_networks.GenerativeNetwork`and a
discriminative :class:`~qiskit.aqua.components.neural_networks.DiscriminativeNetwork`
network to learn the probability distribution underlying given training data.
"""

Expand Down Expand Up @@ -131,7 +138,7 @@ def __init__(self, data: np.ndarray, bounds: Optional[np.ndarray] = None,

@property
def seed(self):
""" returns seed """
""" returns random seed """
return self._random_seed

@seed.setter
Expand Down Expand Up @@ -206,21 +213,21 @@ def set_discriminator(self, discriminator=None):

@property
def g_loss(self):
""" returns g loss """
""" returns generator loss """
return self._g_loss

@property
def d_loss(self):
""" returns d loss """
""" returns discriminator loss """
return self._d_loss

@property
def rel_entr(self):
""" returns relative entropy """
""" returns relative entropy between target and trained distribution """
return self._rel_entr

def get_rel_entr(self):
""" get relative entropy """
""" get relative entropy between target and trained distribution """
samples_gen, prob_gen = self._generator.get_output(self._quantum_instance)
temp = np.zeros(len(self._grid_elements))
for j, sample in enumerate(samples_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

""" Discriminative Quantum or Classical Neural Networks. """
""" Discriminative Quantum or Classical Neural Networks."""

from abc import ABC, abstractmethod

Expand Down
6 changes: 3 additions & 3 deletions qiskit/aqua/components/neural_networks/generative_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# This code is part of Qiskit.
#
# (C) Copyright IBM 2019.
# (C) Copyright IBM 2019, 2020.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -12,7 +12,7 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

""" Generative Quantum and Classical Neural Networks. """
""" Generative Quantum and Classical Neural Networks."""

from abc import ABC, abstractmethod

Expand Down Expand Up @@ -51,7 +51,7 @@ def get_output(self, quantum_instance, qc_state_in, params, shots):
Args:
quantum_instance (QuantumInstance): Quantum Instance, used to run the generator circuit.
qc_state_in (QuantumCircuit): corresponding to the input state
qc_state_in (QuantumCircuit or vector): corresponding to the network input state
params (numpy.ndarray): parameters which should be used to run the generator,
if None use self._params
shots (int): if not None use a number of shots that is different from the number
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# that they have been altered from the originals.

"""
Discriminator
PyTorch Discriminator Neural Network
"""

import os
Expand Down Expand Up @@ -85,7 +85,7 @@ def save_model(self, snapshot_dir: str):

def load_model(self, load_dir: str):
"""
Save discriminator model
Load discriminator model
Args:
load_dir: file with stored pytorch discriminator model to be loaded
Expand Down Expand Up @@ -180,7 +180,7 @@ def gradient_penalty(self, x, lambda_=5., k=0.01, c=1.):

def train(self, data, weights, penalty=True, quantum_instance=None, shots=None):
"""
Perform one training step w.r.t to the discriminator's parameters
Perform one training step w.r.t. to the discriminator's parameters
Args:
data (tuple):
Expand Down
21 changes: 17 additions & 4 deletions qiskit/aqua/components/neural_networks/quantum_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# that they have been altered from the originals.

"""
Generator
Quantum Generator
"""

from typing import Optional, List, Union
Expand All @@ -38,6 +39,14 @@
class QuantumGenerator(GenerativeNetwork):
"""
Quantum Generator
The quantum generator is a parametrized quantum circuit which can be trained with the
:class:`~qiskit.aqua.algorithms.adaptive.qgan.QGAN` algorithm
to generate a quantum state which approximates the probability
distribution of given training data. At the beginning of the training the parameters will
be set randomly, thus, the output will is random. Throughout the training the quantum
generator learns to represent the target distribution.
Eventually, the trained generator can be used for state preparation in e.g. QAE.
"""

def __init__(self,
Expand Down Expand Up @@ -178,7 +187,7 @@ def set_seed(self, seed):

def set_discriminator(self, discriminator):
"""
Set discriminator
Set discriminator network.
Args:
discriminator (Discriminator): Discriminator used to compute the loss function.
Expand Down Expand Up @@ -211,7 +220,11 @@ def construct_circuit(self, params=None):

def get_output(self, quantum_instance, qc_state_in=None, params=None, shots=None):
"""
Get data samples from the generator.
Get classical data samples from the generator.
Running the quantum generator circuit results in a quantum state.
To train this generator with a classical discriminator, we need to sample classical outputs
by measuring the quantum state and mapping them to feature space defined by the training
data.
Args:
quantum_instance (QuantumInstance): Quantum Instance, used to run the generator
Expand Down Expand Up @@ -279,7 +292,7 @@ def get_output(self, quantum_instance, qc_state_in=None, params=None, shots=None

def loss(self, x, weights): # pylint: disable=arguments-differ
"""
Loss function
Loss function for training the generator's parameters.
Args:
x (numpy.ndarray): sample label (equivalent to discriminator output)
Expand Down

0 comments on commit 62311df

Please sign in to comment.