Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jitter Classes (#249) #253

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
114 changes: 94 additions & 20 deletions dLux/detector_layers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
from abc import abstractmethod
import jax
import jax.numpy as np
from jax import Array
from jax.scipy.stats import norm
from jax.scipy.stats import norm, multivariate_normal
from zodiax import Base
import dLux

Expand Down Expand Up @@ -101,45 +102,117 @@ class ApplyJitter(DetectorLayer):

Attributes
----------
sigma : Array, pixels
The standard deviation of the Gaussian kernel, in units of pixels.
kernel_size : int
The size of the convolution kernel to use.
The size in pixels of the convolution kernel to use.
r : float, arcseconds
The magnitude of the jitter.
shear : float
The shear of the jitter.
phi : float, degrees
The angle of the jitter.
"""

kernel_size: int
sigma: Array

def __init__(self: DetectorLayer, sigma: Array, kernel_size: int = 10):
r: float = None
shear: float = None
phi: float = None

def __init__(
self: DetectorLayer,
r: float,
shear: float = 1,
phi: float = 0,
kernel_size: int = 10,
):
"""
Constructor for the ApplyJitter class.

Parameters
----------
sigma : Array, pixels
The standard deviation of the Gaussian kernel, in units of pixels.
r : float, arcseconds
The jitter magnitude in arcseconds, defined as the standard deviation
of the multivariate Gaussian kernel along the major axis. For a
symmetric jitter (shear = 1), r is simply the standard deviation.
shear : float
A measure of how asymmetric the jitter is. Defined as the ratio between
the standard deviations of the minor/major axes of the multivariate
Gaussian kernel. It must lie on the interval (0, 1]. A shear of 1
corresponds to a symmetric jitter, while as shear approaches zero the
jitter kernel becomes more linear.
phi : float
The angle of the jitter in degrees.
kernel_size : int = 10
The size of the convolution kernel to use.
The size of the convolution kernel in pixels to use.
"""
super().__init__()

# checking shear is valid
if shear > 1 or shear <= 0:
raise ValueError("shear must lie on the interval (0, 1]")

self.kernel_size = int(kernel_size)
self.sigma = np.asarray(sigma, dtype=float)
if self.sigma.ndim != 0:
raise ValueError("sigma must be a scalar array.")
self.r = r
self.shear = shear
self.phi = phi

def generate_kernel(self: DetectorLayer, pixel_scale: Array) -> Array:
@property
def covariance_matrix(self):
"""
Generates the covariance matrix for the multivariate normal distribution.

Returns
-------
covariance_matrix : Array
The covariance matrix.
"""
Generates the normalised Gaussian kernel.
# Compute the rotation angle
rot_angle = np.radians(self.phi)

# Construct the rotation matrix
R = np.array(
[
[np.cos(rot_angle), -np.sin(rot_angle)],
[np.sin(rot_angle), np.cos(rot_angle)],
]
)

# Construct the skew matrix
base_matrix = np.array(
[
[self.r**2, 0],
[0, (self.r * self.shear) ** 2],
]
)

# Compute the covariance matrix
covariance_matrix = np.dot(np.dot(R, base_matrix), R.T)

return covariance_matrix

def generate_kernel(self, pixel_scale: float) -> Array:
"""
Generates the normalised multivariate Gaussian kernel.

Parameters
----------
pixel_scale : float
The pixel scale of the image in arcseconds per pixel.

Returns
-------
kernel : Array
The Gaussian kernel.
The normalised Gaussian kernel.
"""
# Generate distribution
sigma = self.sigma * pixel_scale
x = np.linspace(-10, 10, self.kernel_size) * pixel_scale
kernel = norm.pdf(x, scale=sigma) * norm.pdf(x[:, None], scale=sigma)
extent = pixel_scale * self.kernel_size # kernel size in arcseconds
x = np.linspace(0, extent, self.kernel_size) - 0.5 * extent
xs, ys = np.meshgrid(x, x)
pos = np.dstack((xs, ys))

kernel = multivariate_normal.pdf(
pos, mean=np.array([0.0, 0.0]), cov=self.covariance_matrix
)

return kernel / np.sum(kernel)

def __call__(self: DetectorLayer, image: Image()) -> Image():
Expand All @@ -156,7 +229,8 @@ def __call__(self: DetectorLayer, image: Image()) -> Image():
image : Image
The transformed image.
"""
kernel = self.generate_kernel(image.pixel_scale)
kernel = self.generate_kernel(dLux.utils.rad_to_arcsec(image.pixel_scale))

return image.convolve(kernel)


Expand Down