Skip to content

Commit

Permalink
Fix pulse mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Jun 29, 2022
1 parent 1082240 commit be510fb
Show file tree
Hide file tree
Showing 15 changed files with 81 additions and 61 deletions.
10 changes: 5 additions & 5 deletions qiskit/pulse/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def get_context(self) -> ScheduleBlock:
"""
return self._context_stack[-1]

@property
@property # type: ignore
@_requires_backend
def num_qubits(self):
"""Get the number of qubits in the backend."""
Expand All @@ -664,7 +664,7 @@ def transpiler_settings(self) -> Mapping:
"""The builder's transpiler settings."""
return self._transpiler_settings

@transpiler_settings.setter
@transpiler_settings.setter # type: ignore
@_compile_lazy_circuit_before
def transpiler_settings(self, settings: Mapping):
self._compile_lazy_circuit()
Expand All @@ -675,7 +675,7 @@ def circuit_scheduler_settings(self) -> Mapping:
"""The builder's circuit to pulse scheduler settings."""
return self._circuit_scheduler_settings

@circuit_scheduler_settings.setter
@circuit_scheduler_settings.setter # type: ignore
@_compile_lazy_circuit_before
def circuit_scheduler_settings(self, settings: Mapping):
self._compile_lazy_circuit()
Expand Down Expand Up @@ -808,7 +808,7 @@ def call_subroutine(
self.append_instruction(call_def)

@_requires_backend
def call_gate(self, gate: circuit.Gate, qubits: Tuple[int, ...], lazy: bool = True):
def call_gate(self, gate: circuit.Gate, qubits: Union[int, Tuple[int, ...]], lazy: bool = True):
"""Call the circuit ``gate`` in the pulse program.
The qubits are assumed to be defined on physical qubits.
Expand Down Expand Up @@ -2273,7 +2273,7 @@ def delay_qubits(duration: int, *qubits: Union[int, Iterable[int]]):


# Gate instructions
def call_gate(gate: circuit.Gate, qubits: Tuple[int, ...], lazy: bool = True):
def call_gate(gate: circuit.Gate, qubits: Union[int, Tuple[int, ...]], lazy: bool = True):
"""Call a gate and lazily schedule it to its corresponding
pulse instruction.
Expand Down
6 changes: 4 additions & 2 deletions qiskit/pulse/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
.. autoclass:: Channel
"""
from abc import ABCMeta
from typing import Any, Set, Union
from typing import Any, Set, Union, Optional

import numpy as np

Expand Down Expand Up @@ -143,7 +143,7 @@ def name(self) -> str:
def __repr__(self):
return f"{self.__class__.__name__}({self._index})"

def __eq__(self, other: "Channel") -> bool:
def __eq__(self, other: object) -> bool:
"""Return True iff self and other are equal, specifically, iff they have the same type
and the same index.
Expand All @@ -153,6 +153,8 @@ def __eq__(self, other: "Channel") -> bool:
Returns:
True iff equal.
"""
if not isinstance(other, Channel):
return NotImplemented
return type(self) is type(other) and self._index == other._index

def __hash__(self):
Expand Down
17 changes: 11 additions & 6 deletions qiskit/pulse/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
from typing import Dict, Union, Tuple, Optional

from .channels import PulseChannel, DriveChannel, MeasureChannel
from .channels import DriveChannel, MeasureChannel
from .exceptions import PulseError


Expand Down Expand Up @@ -118,8 +118,10 @@ class LoConfig:

def __init__(
self,
channel_los: Optional[Dict[PulseChannel, float]] = None,
lo_ranges: Optional[Dict[PulseChannel, Union[LoRange, Tuple[int]]]] = None,
channel_los: Optional[Dict[Union[DriveChannel, MeasureChannel], float]] = None,
lo_ranges: Optional[
Dict[Union[DriveChannel, MeasureChannel], Union[LoRange, Tuple[int]]]
] = None,
):
"""Lo channel configuration data structure.
Expand All @@ -131,9 +133,9 @@ def __init__(
PulseError: If channel is not configurable or set lo is out of range.
"""
self._q_lo_freq = {}
self._m_lo_freq = {}
self._lo_ranges = {}
self._q_lo_freq: Dict[Union[DriveChannel, MeasureChannel], float] = {}
self._m_lo_freq: Dict[Union[DriveChannel, MeasureChannel], float] = {}
self._lo_ranges: Dict[Union[DriveChannel, MeasureChannel], LoRange] = {}

lo_ranges = lo_ranges if lo_ranges else {}
for channel, freq in lo_ranges.items():
Expand Down Expand Up @@ -176,12 +178,15 @@ def check_lo(self, channel: Union[DriveChannel, MeasureChannel], freq: float) ->
freq: lo frequency
Raises:
PulseError: If freq is outside of channels range
Returns:
True if lo is valid for channel
"""
lo_ranges = self._lo_ranges
if channel in lo_ranges:
lo_range = lo_ranges[channel]
if not lo_range.includes(freq):
raise PulseError(f"Specified LO freq {freq:f} is out of range {lo_range}")
return True

def channel_lo(self, channel: Union[DriveChannel, MeasureChannel]) -> float:
"""Return channel lo.
Expand Down
2 changes: 1 addition & 1 deletion qiskit/pulse/instructions/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _get_arg_hash(self):
"""A helper function to generate hash of parameters."""
return hash(tuple(self.arguments.items()))

def __eq__(self, other: "Instruction") -> bool:
def __eq__(self, other: object) -> bool:
"""Check if this instruction is equal to the `other` instruction.
Instructions are equal if they share the same type, operands, and channels.
Expand Down
14 changes: 8 additions & 6 deletions qiskit/pulse/instructions/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def __init__(
PulseError: If the input ``channels`` are not all of
type :class:`Channel`.
"""
self._operands = operands
self._operands: Tuple = operands
self._name = name
self._hash = None
self._hash: Optional[int] = None

for channel in self.channels:
if not isinstance(channel, Channel):
Expand All @@ -75,7 +75,7 @@ def operands(self) -> Tuple:
"""Return instruction operands."""
return self._operands

@property
@property # type: ignore
@abstractmethod
def channels(self) -> Tuple[Channel]:
"""Returns the channels that this schedule uses."""
Expand All @@ -97,12 +97,12 @@ def duration(self) -> int:
raise NotImplementedError

@property
def _children(self) -> Tuple["Instruction"]:
def _children(self) -> Tuple["Instruction", ...]:
"""Instruction has no child nodes."""
return ()

@property
def instructions(self) -> Tuple[Tuple[int, "Instruction"]]:
def instructions(self) -> Tuple[Tuple[int, "Instruction"], ...]:
"""Iterable for getting instructions from Schedule tree."""
return tuple(self._instructions())

Expand Down Expand Up @@ -262,11 +262,13 @@ def draw(
channels=channels,
)

def __eq__(self, other: "Instruction") -> bool:
def __eq__(self, other: object) -> bool:
"""Check if this Instruction is equal to the `other` instruction.
Equality is determined by the instruction sharing the same operands and channels.
"""
if not isinstance(other, Instruction):
return NotImplemented
return isinstance(other, type(self)) and self.operands == other.operands

def __hash__(self) -> int:
Expand Down
20 changes: 10 additions & 10 deletions qiskit/pulse/library/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def sin(times: np.ndarray, amp: complex, freq: float, phase: float = 0) -> np.nd


def _fix_gaussian_width(
gaussian_samples,
amp: float,
gaussian_samples: np.ndarray,
amp: complex,
center: float,
sigma: float,
zeroed_width: Optional[float] = None,
rescale_amp: bool = False,
ret_scale_factor: bool = False,
) -> np.ndarray:
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
r"""Enforce that the supplied gaussian pulse is zeroed at a specific width.
This is achieved by subtracting $\Omega_g(center \pm zeroed_width/2)$ from all samples.
Expand All @@ -132,7 +132,7 @@ def _fix_gaussian_width(

zero_offset = gaussian(np.array([zeroed_width / 2]), amp, 0, sigma)
gaussian_samples -= zero_offset
amp_scale_factor = 1.0
amp_scale_factor: Union[complex, float, np.ndarray] = 1.0
if rescale_amp:
amp_scale_factor = amp / (amp - zero_offset) if amp - zero_offset != 0 else 1.0
gaussian_samples *= amp_scale_factor
Expand Down Expand Up @@ -198,7 +198,7 @@ def gaussian_deriv(
ret_gaussian: bool = False,
zeroed_width: Optional[float] = None,
rescale_amp: bool = False,
) -> np.ndarray:
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
r"""Continuous unnormalized gaussian derivative pulse.
Args:
Expand Down Expand Up @@ -229,14 +229,14 @@ def gaussian_deriv(


def _fix_sech_width(
sech_samples,
amp: float,
sech_samples: np.ndarray,
amp: complex,
center: float,
sigma: float,
zeroed_width: Optional[float] = None,
rescale_amp: bool = False,
ret_scale_factor: bool = False,
) -> np.ndarray:
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
r"""Enforce that the supplied sech pulse is zeroed at a specific width.
This is achieved by subtracting $\Omega_g(center \pm zeroed_width/2)$ from all samples.
Expand All @@ -256,7 +256,7 @@ def _fix_sech_width(

zero_offset = sech(np.array([zeroed_width / 2]), amp, 0, sigma)
sech_samples -= zero_offset
amp_scale_factor = 1.0
amp_scale_factor: Union[complex, float, np.ndarray] = 1.0
if rescale_amp:
amp_scale_factor = amp / (amp - zero_offset) if amp - zero_offset != 0 else 1.0
sech_samples *= amp_scale_factor
Expand Down Expand Up @@ -316,7 +316,7 @@ def sech(

def sech_deriv(
times: np.ndarray, amp: complex, center: float, sigma: float, ret_sech: bool = False
) -> np.ndarray:
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""Continuous unnormalized sech derivative pulse.
Args:
Expand Down
4 changes: 3 additions & 1 deletion qiskit/pulse/library/parametric_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def is_parameterized(self) -> bool:
"""Return True iff the instruction is parameterized."""
return any(_is_parameterized(val) for val in self.parameters.values())

def __eq__(self, other: Pulse) -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, Pulse):
return NotImplemented
return super().__eq__(other) and self.parameters == other.parameters

def __hash__(self) -> int:
Expand Down
6 changes: 4 additions & 2 deletions qiskit/pulse/library/pulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def id(self) -> int: # pylint: disable=invalid-name
"""Unique identifier for this pulse."""
return id(self)

@property
@property # type: ignore
@abstractmethod
def parameters(self) -> Dict[str, Any]:
"""Return a dictionary containing the pulse's parameters."""
Expand Down Expand Up @@ -123,7 +123,9 @@ def draw(
)

@abstractmethod
def __eq__(self, other: "Pulse") -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, Pulse):
return NotImplemented
return isinstance(other, type(self))

@abstractmethod
Expand Down
12 changes: 7 additions & 5 deletions qiskit/pulse/library/symbolic_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def _is_amplitude_valid(symbolic_pulse: "SymbolicPulse") -> bool:
return False


def _get_expression_args(expr: sym.Expr, params: Dict[str, float]) -> List[float]:
def _get_expression_args(
expr: sym.Expr, params: Dict[str, float]
) -> List[Union[np.ndarray, float]]:
"""A helper function to get argument to evaluate expression.
Args:
Expand All @@ -114,7 +116,7 @@ def _get_expression_args(expr: sym.Expr, params: Dict[str, float]) -> List[float
Raises:
PulseError: When a free symbol value is not defined in the pulse instance parameters.
"""
args = []
args: List[Union[np.ndarray, float]] = []
for symbol in sorted(expr.free_symbols, key=lambda s: s.name):
if symbol.name == "t":
# 't' is a special parameter to represent time vector.
Expand Down Expand Up @@ -156,7 +158,7 @@ def __init__(self, attribute: str):
the target expression to evaluate.
"""
self.attribute = attribute
self.lambda_funcs = dict()
self.lambda_funcs: Dict[int, Callable] = dict()

def __get__(self, instance, owner) -> Callable:
expr = getattr(instance, self.attribute, None)
Expand Down Expand Up @@ -543,11 +545,11 @@ def is_parameterized(self) -> bool:

@property
def parameters(self) -> Dict[str, Any]:
params = {"duration": self.duration}
params: Dict[str, Union[ParameterExpression, complex, int]] = {"duration": self.duration}
params.update(self._params)
return params

def __eq__(self, other: "SymbolicPulse") -> bool:
def __eq__(self, other: object) -> bool:

if not isinstance(other, SymbolicPulse):
return NotImplemented
Expand Down
6 changes: 4 additions & 2 deletions qiskit/pulse/library/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,17 @@ def parameters(self) -> Dict[str, Any]:
"""Return a dictionary containing the pulse's parameters."""
return {}

def __eq__(self, other: Pulse) -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, Pulse):
return NotImplemented
return (
super().__eq__(other)
and self.samples.shape == other.samples.shape
and np.allclose(self.samples, other.samples, rtol=0, atol=self.epsilon)
)

def __hash__(self) -> int:
return hash(self.samples.tostring())
return hash(self.samples.tobytes())

def __repr__(self) -> str:
opt = np.get_printoptions()
Expand Down
4 changes: 2 additions & 2 deletions qiskit/pulse/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@

"""Module for common pulse programming macros."""

from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Sequence

from qiskit.pulse import channels, exceptions, instructions, utils
from qiskit.pulse.instruction_schedule_map import InstructionScheduleMap
from qiskit.pulse.schedule import Schedule


def measure(
qubits: List[int],
qubits: Sequence[int],
backend=None,
inst_map: Optional[InstructionScheduleMap] = None,
meas_map: Optional[Union[List[List[int]], Dict[int, List[int]]]] = None,
Expand Down
Loading

0 comments on commit be510fb

Please sign in to comment.