Skip to content

Commit

Permalink
Add FFT Ops (keras-team#480)
Browse files Browse the repository at this point in the history
* Add FFT Ops

* Fixes

* Fix torch

* Address Matt's comments

* Address Francois' comments

* Shift docstrings to correct fns

* Add NumPy backend FFT ops

* Fix numpy backend

* Minor change

* Redirect NumPy FFT to JAX
  • Loading branch information
abheesht17 authored and adi-kmt committed Jul 21, 2023
1 parent 9c75bd0 commit a19a28e
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 0 deletions.
40 changes: 40 additions & 0 deletions keras_core/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,43 @@ def qr(x, mode="reduced"):
f"Received: mode={mode}"
)
return jax.numpy.linalg.qr(x, mode=mode)


def _get_complex_tensor_from_tuple(a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
f"Received: a={a}"
)
# `convert_to_tensor` does not support passing complex tensors. We separate
# the input out into real and imaginary and convert them separately.
real, imag = a
# Check shapes.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
"Both the real and imaginary parts should have the same shape. "
f"Received: a[0].shape = {real.shape}, a[1].shape = {imag.shape}"
)
# Ensure dtype is float.
if not jnp.issubdtype(real.dtype, jnp.floating) or not jnp.issubdtype(
imag.dtype, jnp.floating
):
raise ValueError(
"At least one tensor in input `a` is not of type float."
f"Received: a={a}."
)
complex_input = jax.lax.complex(real, imag)
return complex_input


def fft(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = jax.numpy.fft.fft(complex_input)
return jax.numpy.real(complex_output), jax.numpy.imag(complex_output)


def fft2(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = jax.numpy.fft.fft2(complex_input)
return jax.numpy.real(complex_output), jax.numpy.imag(complex_output)
13 changes: 13 additions & 0 deletions keras_core/backend/numpy/math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np

from keras_core.backend.jax.math import fft as jax_fft
from keras_core.backend.jax.math import fft2 as jax_fft2


def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if num_segments is None:
Expand Down Expand Up @@ -74,3 +77,13 @@ def qr(x, mode="reduced"):
f"Received: mode={mode}"
)
return np.linalg.qr(x, mode=mode)


def fft(a):
real, imag = jax_fft(a)
return np.array(real), np.array(imag)


def fft2(a):
real, imag = jax_fft2(a)
return np.array(real), np.array(imag)
42 changes: 42 additions & 0 deletions keras_core/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import tensorflow as tf

from keras_core.backend.tensorflow.core import convert_to_tensor


def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if sorted:
Expand Down Expand Up @@ -33,3 +35,43 @@ def qr(x, mode="reduced"):
if mode == "reduced":
return tf.linalg.qr(x)
return tf.linalg.qr(x, full_matrices=True)


def _get_complex_tensor_from_tuple(a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
f"Received: a={a}"
)
# `convert_to_tensor` does not support passing complex tensors. We separate
# the input out into real and imaginary and convert them separately.
real, imag = a
real = convert_to_tensor(real)
imag = convert_to_tensor(imag)
# Check shapes.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
"Both the real and imaginary parts should have the same shape. "
f"Received: a[0].shape = {real.shape}, a[1].shape = {imag.shape}"
)
# Ensure dtype is float.
if not real.dtype.is_floating or not imag.dtype.is_floating:
raise ValueError(
"At least one tensor in input `a` is not of type float."
f"Received: a={a}."
)
complex_input = tf.dtypes.complex(real, imag)
return complex_input


def fft(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = tf.signal.fft(complex_input)
return tf.math.real(complex_output), tf.math.imag(complex_output)


def fft2(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = tf.signal.fft2d(complex_input)
return tf.math.real(complex_output), tf.math.imag(complex_output)
41 changes: 41 additions & 0 deletions keras_core/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,44 @@ def qr(x, mode="reduced"):
)
x = convert_to_tensor(x)
return torch.linalg.qr(x, mode=mode)


def _get_complex_tensor_from_tuple(a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
f"Received: a={a}"
)
# `convert_to_tensor` does not support passing complex tensors. We separate
# the input out into real and imaginary and convert them separately.
real, imag = a
real = convert_to_tensor(real)
imag = convert_to_tensor(imag)
# Check shape.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
"Both the real and imaginary parts should have the same shape. "
f"Received: a[0].shape = {real.shape}, a[1].shape = {imag.shape}"
)
# Ensure dtype is float.
if not torch.is_floating_point(real) or not torch.is_floating_point(imag):
raise ValueError(
"At least one tensor in input `a` is not of type float."
f"Received: a={a}."
)

complex_input = torch.complex(real, imag)
return complex_input


def fft(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = torch.fft.fft(complex_input)
return complex_output.real, complex_output.imag


def fft2(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = torch.fft.fft2(complex_input)
return complex_output.real, complex_output.imag
138 changes: 138 additions & 0 deletions keras_core/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,141 @@ def qr(x, mode="reduced"):
if any_symbolic_tensors((x,)):
return Qr(mode=mode).symbolic_call(x)
return backend.math.qr(x, mode=mode)


class FFT(Operation):
def compute_output_spec(self, a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
f"imaginary. Received: a={a}"
)

real, imag = a
# Both real and imaginary parts should have the same shape.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
"imaginary. Both the real and imaginary parts should have the "
f"same shape. Received: a[0].shape = {real.shape}, "
f"a[1].shape = {imag.shape}"
)

# We are calculating 1D FFT. Hence, rank >= 1.
if len(real.shape) < 1:
raise ValueError(
f"Input should have rank >= 1. "
f"Received: input.shape = {real.shape}"
)

# The axis along which we are calculating FFT should be fully-defined.
m = real.shape[-1]
if m is None:
raise ValueError(
f"Input should have its {self.axis}th axis fully-defined. "
f"Received: input.shape = {real.shape}"
)

return (
KerasTensor(shape=real.shape, dtype=real.dtype),
KerasTensor(shape=imag.shape, dtype=imag.dtype),
)

def call(self, x):
return backend.math.fft(x)


class FFT2(Operation):
def compute_output_spec(self, a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
f"imaginary. Received: a={a}"
)

real, imag = a
# Both real and imaginary parts should have the same shape.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
"imaginary. Both the real and imaginary parts should have the "
f"same shape. Received: a[0].shape = {real.shape}, "
f"a[1].shape = {imag.shape}"
)
# We are calculating 2D FFT. Hence, rank >= 2.
if len(real.shape) < 2:
raise ValueError(
f"Input should have rank >= 2. "
f"Received: input.shape = {real.shape}"
)

# The axes along which we are calculating FFT should be fully-defined.
m = real.shape[-1]
n = real.shape[-2]
if m is None or n is None:
raise ValueError(
f"Input should have its {self.axes} axes fully-defined. "
f"Received: input.shape = {real.shape}"
)

return (
KerasTensor(shape=real.shape, dtype=real.dtype),
KerasTensor(shape=imag.shape, dtype=imag.dtype),
)

def call(self, x):
return backend.math.fft2(x)


@keras_core_export("keras_core.ops.fft")
def fft(a):
"""Computes the Fast Fourier Transform along last axis of input.
Args:
a: Tuple of the real and imaginary parts of the input tensor. Both
tensors in the tuple should be of floating type.
Returns:
A tuple containing two tensors - the real and imaginary parts of the
output tensor.
Example:
>>> a = (
... keras_core.ops.convert_to_tensor([1., 2.]),
... keras_core.ops.convert_to_tensor([0., 1.]),
... )
>>> fft(x)
(array([ 3., -1.], dtype=float32), array([ 1., -1.], dtype=float32))
"""
if any_symbolic_tensors(a):
return FFT().symbolic_call(a)
return backend.math.fft(a)


@keras_core_export("keras_core.ops.fft2")
def fft2(a):
"""Computes the 2D Fast Fourier Transform along the last two axes of input.
Args:
a: Tuple of the real and imaginary parts of the input tensor. Both
tensors in the tuple should be of floating type.
Returns:
A tuple containing two tensors - the real and imaginary parts of the
output.
Example:
>>> x = (
... keras_core.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
... keras_core.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
... )
>>> fft2(x)
(array([[ 6., 0.],
[ 0., -2.]], dtype=float32), array([[ 2., 0.],
[ 0., -2.]], dtype=float32))
"""
if any_symbolic_tensors(a):
return FFT2().symbolic_call(a)
return backend.math.fft2(a)
58 changes: 58 additions & 0 deletions keras_core/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ def test_qr(self):
self.assertEqual(q.shape, qref_shape)
self.assertEqual(r.shape, rref_shape)

def test_fft(self):
real = KerasTensor((None, 4, 3), dtype="float32")
imag = KerasTensor((None, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft((real, imag))
ref = np.fft.fft(np.ones((2, 4, 3)))
ref_shape = (None,) + ref.shape[1:]
self.assertEqual(real_output.shape, ref_shape)
self.assertEqual(imag_output.shape, ref_shape)

def test_fft2(self):
real = KerasTensor((None, 4, 3), dtype="float32")
imag = KerasTensor((None, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft2((real, imag))
ref = np.fft.fft2(np.ones((2, 4, 3)))
ref_shape = (None,) + ref.shape[1:]
self.assertEqual(real_output.shape, ref_shape)
self.assertEqual(imag_output.shape, ref_shape)


class MathOpsStaticShapeTest(testing.TestCase):
@pytest.mark.skipif(
Expand Down Expand Up @@ -100,6 +118,22 @@ def test_qr(self):
self.assertEqual(q.shape, qref.shape)
self.assertEqual(r.shape, rref.shape)

def test_fft(self):
real = KerasTensor((2, 4, 3), dtype="float32")
imag = KerasTensor((2, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft((real, imag))
ref = np.fft.fft(np.ones((2, 4, 3)))
self.assertEqual(real_output.shape, ref.shape)
self.assertEqual(imag_output.shape, ref.shape)

def test_fft2(self):
real = KerasTensor((2, 4, 3), dtype="float32")
imag = KerasTensor((2, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft2((real, imag))
ref = np.fft.fft2(np.ones((2, 4, 3)))
self.assertEqual(real_output.shape, ref.shape)
self.assertEqual(imag_output.shape, ref.shape)


class MathOpsCorrectnessTest(testing.TestCase):
@pytest.mark.skipif(
Expand Down Expand Up @@ -226,3 +260,27 @@ def test_qr(self):
qref, rref = np.linalg.qr(x, mode="complete")
self.assertAllClose(qref, q)
self.assertAllClose(rref, r)

def test_fft(self):
real = np.random.random((2, 4, 3))
imag = np.random.random((2, 4, 3))
complex_arr = real + 1j * imag

real_output, imag_output = kmath.fft((real, imag))
ref = np.fft.fft(complex_arr)
real_ref = np.real(ref)
imag_ref = np.imag(ref)
self.assertAllClose(real_ref, real_output)
self.assertAllClose(imag_ref, imag_output)

def test_fft2(self):
real = np.random.random((2, 4, 3))
imag = np.random.random((2, 4, 3))
complex_arr = real + 1j * imag

real_output, imag_output = kmath.fft2((real, imag))
ref = np.fft.fft2(complex_arr)
real_ref = np.real(ref)
imag_ref = np.imag(ref)
self.assertAllClose(real_ref, real_output)
self.assertAllClose(imag_ref, imag_output)

0 comments on commit a19a28e

Please sign in to comment.