Skip to content

Commit

Permalink
timeinterval refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
flowerthrower committed Nov 17, 2023
1 parent 76f461c commit df2ac64
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 53 deletions.
4 changes: 2 additions & 2 deletions src/qutip_qoc/analytical_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def optimize_pulses(
bounds : list of pairs of floats
[(lower, upper), ...]
Bounds for the pulse parameters.
tlist : array_like
tslots : array_like
List of times for the calculataion of final pulse sequence.
During integration only the first and last time are used.
kwargs : dict of dict
Expand Down Expand Up @@ -177,7 +177,7 @@ def optimize_pulses(
- atol, rtol : float
Absolute and relative tolerance of the ODE integrator.
- nsteps : int
Maximum number of (internally defined) steps allowed in one ``tlist``
Maximum number of (internally defined) steps allowed in one ``tslots``
step.
- max_step : float, 0
Maximum lenght of one internal step. When using pulses, it should be
Expand Down
File renamed without changes
File renamed without changes
File renamed without changes
34 changes: 16 additions & 18 deletions doc/seminar.ipynb → src/qutip_qoc/examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -14,7 +14,7 @@
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from diffrax import Dopri5, PIDController, Dopri8,Tsit5\n",
"from diffrax import Dopri5, PIDController, Dopri8\n",
"\n",
"from optimize import optimize_pulses\n",
"from time_interval import TimeInterval\n",
Expand Down Expand Up @@ -172,7 +172,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"documentation\\DYNAMO.PNG\" alt=\"dynamo\" width=\"600\"/>"
"<img src=\"doc_images\\DYNAMO.png\" alt=\"dynamo\" width=\"600\"/>"
]
},
{
Expand Down Expand Up @@ -213,7 +213,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"documentation\\CRAB.png\" alt=\"crab\" width=\"800\"/>"
"<img src=\"doc_images\\CRAB.png\" alt=\"crab\" width=\"800\"/>"
]
},
{
Expand Down Expand Up @@ -380,17 +380,15 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"initial = qt.qeye(2)\n",
"target = hadamard()\n",
"\n",
"initial = qt.sprepost(initial, initial.dag())\n",
"target = qt.sprepost(target , target.dag() )\n",
"\n",
"objective = Objective(initial, ... , target)"
"target = qt.sprepost(target , target.dag() )"
]
},
{
Expand Down Expand Up @@ -418,7 +416,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -436,18 +434,18 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"π = np.pi\n",
"num_ts = 100\n",
"interval = TimeInterval(evo_time= 2*π, num_tslots=num_ts)"
"interval = TimeInterval(evo_time= 2*π, n_tslots=num_ts)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -508,7 +506,7 @@
}
],
"source": [
"init_evo = qt.mesolve(H, initial, interval.tlist)\n",
"init_evo = qt.mesolve(H, initial, interval.tslots)\n",
"\n",
"qt.hinton(init_evo.final_state)"
]
Expand All @@ -524,7 +522,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"documentation\\QOC_structure.png\" alt=\"crab\" width=\"600\"/>"
"<img src=\"doc_images\\QOC.png\" alt=\"qoc\" width=\"600\"/>"
]
},
{
Expand Down Expand Up @@ -843,8 +841,8 @@
" for i in range(len(res.optimized_controls)):\n",
" ax[i].xaxis.set_label_text('Time')\n",
" ax[i].yaxis.set_label_text('Control ' + y_labels[i])\n",
" ax[i].plot(res.time_interval.tlist, res.guess_controls[i], label='Guess')\n",
" ax[i].plot(res.time_interval.tlist, res.optimized_controls[i], label='Optimized')\n",
" ax[i].plot(res.time_interval.tslots, res.guess_controls[i], label='Guess')\n",
" ax[i].plot(res.time_interval.tslots, res.optimized_controls[i], label='Optimized')\n",
" ax[i].legend()"
]
},
Expand Down Expand Up @@ -975,7 +973,7 @@
"q_init = [1, 1, 0] # q[0] * sin(q[1] * t + q[2])\n",
"r_init = [1, 1, 0]\n",
"\n",
"init_evo = qt.mesolve(H, initial, interval.tlist,\n",
"init_evo = qt.mesolve(H, initial, interval.tslots,\n",
" options={'normalize_output': False},\n",
" args={\"p\": p_init, \"q\": q_init, \"r\": r_init})\n",
"\n",
Expand Down Expand Up @@ -1420,7 +1418,7 @@
"n_var = 3\n",
"n_tot = n_sup * n_var\n",
"\n",
"interval = TimeInterval(evo_time= 2*π, num_tslots=1000)"
"interval = TimeInterval(evo_time= 2*π, n_tslots=1000)"
]
},
{
Expand Down
7 changes: 6 additions & 1 deletion src/qutip_qoc/objective.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import qutip as qt


class Objective:
"""
A class for storing all
Expand All @@ -19,9 +22,11 @@ class Objective:
----------
initial : :class:`qutip.Qobj`
The initial state or operator to be transformed.
H : callable, list or Qobj
H : callable, list
A specification of the time-depedent quantum object.
See :class:`qutip.QobjEvo` for details and examples.
target : :class:`qutip.Qobj`
The target state or operator.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/qutip_qoc/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def optimize_pulses(objectives, pulse_options, time_interval, time_options={}, a
ctrls=Hc_lst,
initial=init,
target=targ,
num_tslots=time_interval.num_tslots,
n_tslots=time_interval.n_tslots,
evo_time=time_interval.evo_time,
tau=None, # implicitly derived from tlist
tau=None, # implicitly derived from tslots
amp_lbound=lbound,
amp_ubound=ubound,
fid_err_targ=algorithm_kwargs.get("fid_err_targ", 1e-10),
Expand Down
6 changes: 3 additions & 3 deletions src/qutip_qoc/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def optimized_controls(self):
control = Hc[1]
if callable(control):
cf = []
for t in self.time_interval.tlist:
for t in self.time_interval.tslots:
cf.append(control(t, xf))
else:
cf = xf
Expand All @@ -157,7 +157,7 @@ def guess_controls(self):
control = Hc[1]
if callable(control):
c0 = []
for t in self.time_interval.tlist:
for t in self.time_interval.tslots:
c0.append(control(t, x0))
else:
c0 = x0
Expand Down Expand Up @@ -217,7 +217,7 @@ def final_states(self):
qt.mesolve(
obj.H,
obj.initial,
tlist=[0., evo_time],
tslots=[0., evo_time],
args=args_dict,
options={'normalize_output': False}
).final_state
Expand Down
76 changes: 49 additions & 27 deletions src/qutip_qoc/time_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,70 @@

class TimeInterval:
"""
Class for storing a time interval and deriving its attributes.
Attributes
----------
tslots : array_like, optional
List of time slots at which the control pulse is evaluated.
The last element of tslots is the total evolution time.
Can be unevenly spaced.
evo_time : float, optional
Total evolution time.
If given together with n_tslots, tslots is derived from evo_time
and assumed to be evenly spaced.
n_tslots : int, optional
Number of time slots. Length of tslots is n_tslots.
tdiffs : array_like, optional
List of time intervals between time slots.
Can be unevenly spaced.
Length of tdiffs is n_tslots - 1.
Sum over all elements of tdiffs is evo_time.
"""

def __init__(self, tlist=None, evo_time=None, num_tslots=None,
tau=None, bounds=None, guess=None):
self._tlist = tlist
self._tau = tau
def __init__(self, tslots=None, evo_time=None, n_tslots=None, tdiffs=None):
self._tslots = tslots
self._evo_time = evo_time
self._num_tslots = num_tslots
self._n_tslots = n_tslots
self._tdiffs = tdiffs

@property
def tlist(self):
if self._tlist is None:
n_tslots = self.num_tslots
def tslots(self):
if self._tslots is None:
n_tslots = self.n_tslots
if self._evo_time: # derive from evo_time
self._tlist = np.linspace(0., self._evo_time, n_tslots)
elif self._tau: # derive from tau
self._tlist = [sum(self._tau[:i]) for i in range(n_tslots - 1)]
return self._tlist
self._tslots = np.linspace(0., self._evo_time, n_tslots)
elif self._tdiffs: # derive from tdiffs
self._tslots = [sum(self._tdiffs[:i])
for i in range(n_tslots - 1)]
return self._tslots

@property
def tau(self):
if self._tau is None:
tlist = self.tlist
self._tau = np.diff(tlist)
return self._tau
def tdiffs(self):
if self._tdiffs is None:
tslots = self.tslots
self._tdiffs = np.diff(tslots)
return self._tdiffs

@property
def evo_time(self):
if self._evo_time is None:
tlist = self.tlist
self._evo_time = tlist[-1]
tslots = self.tslots
self._evo_time = tslots[-1]
return self._evo_time

@property
def num_tslots(self):
if self._num_tslots is None:
if self._tlist:
self._num_tslots = len(self._tlist)
elif self._tau:
self._num_tslots = len(self._tau) - 1
def n_tslots(self):
if self._n_tslots is None:
if self._tslots:
self._n_tslots = len(self._tslots)
elif self._tdiffs:
self._n_tslots = len(self._tdiffs) - 1
else:
raise ValueError(
"Either tlist, tau, or evo_time + num_tslots must be specified."
"Either tslots, tdiffs, or evo_time + n_tslots must be specified."
)
return self._num_tslots
return self._n_tslots

0 comments on commit df2ac64

Please sign in to comment.