Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FFT Ops #480

Merged
merged 11 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions keras_core/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,43 @@ def qr(x, mode="reduced"):
f"Received: mode={mode}"
)
return jax.numpy.linalg.qr(x, mode=mode)


def _get_complex_tensor_from_tuple(a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
f"Received: a={a}"
)
# `convert_to_tensor` does not support passing complex tensors. We separate
# the input out into real and imaginary and convert them separately.
real, imag = a
# Check shapes.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
"Both the real and imaginary parts should have the same shape. "
f"Received: a[0].shape = {real.shape}, a[1].shape = {imag.shape}"
)
# Ensure dtype is float.
if not jnp.issubdtype(real.dtype, jnp.floating) or not jnp.issubdtype(
imag.dtype, jnp.floating
):
raise ValueError(
"At least one tensor in input `a` is not of type float."
f"Received: a={a}."
)
complex_input = jax.lax.complex(real, imag)
return complex_input


def fft(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = jax.numpy.fft.fft(complex_input)
return jax.numpy.real(complex_output), jax.numpy.imag(complex_output)


def fft2(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = jax.numpy.fft.fft2(complex_input)
return jax.numpy.real(complex_output), jax.numpy.imag(complex_output)
13 changes: 13 additions & 0 deletions keras_core/backend/numpy/math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np

from keras_core.backend.jax.math import fft as jax_fft
from keras_core.backend.jax.math import fft2 as jax_fft2


def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if num_segments is None:
Expand Down Expand Up @@ -74,3 +77,13 @@ def qr(x, mode="reduced"):
f"Received: mode={mode}"
)
return np.linalg.qr(x, mode=mode)


def fft(a):
real, imag = jax_fft(a)
return np.array(real), np.array(imag)


def fft2(a):
real, imag = jax_fft2(a)
return np.array(real), np.array(imag)
42 changes: 42 additions & 0 deletions keras_core/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import tensorflow as tf

from keras_core.backend.tensorflow.core import convert_to_tensor


def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if sorted:
Expand Down Expand Up @@ -33,3 +35,43 @@ def qr(x, mode="reduced"):
if mode == "reduced":
return tf.linalg.qr(x)
return tf.linalg.qr(x, full_matrices=True)


def _get_complex_tensor_from_tuple(a):
abheesht17 marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
f"Received: a={a}"
)
# `convert_to_tensor` does not support passing complex tensors. We separate
# the input out into real and imaginary and convert them separately.
real, imag = a
real = convert_to_tensor(real)
imag = convert_to_tensor(imag)
# Check shapes.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
"Both the real and imaginary parts should have the same shape. "
f"Received: a[0].shape = {real.shape}, a[1].shape = {imag.shape}"
)
# Ensure dtype is float.
if not real.dtype.is_floating or not imag.dtype.is_floating:
raise ValueError(
"At least one tensor in input `a` is not of type float."
f"Received: a={a}."
)
complex_input = tf.dtypes.complex(real, imag)
return complex_input


def fft(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = tf.signal.fft(complex_input)
return tf.math.real(complex_output), tf.math.imag(complex_output)


def fft2(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = tf.signal.fft2d(complex_input)
return tf.math.real(complex_output), tf.math.imag(complex_output)
41 changes: 41 additions & 0 deletions keras_core/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,44 @@ def qr(x, mode="reduced"):
)
x = convert_to_tensor(x)
return torch.linalg.qr(x, mode=mode)


def _get_complex_tensor_from_tuple(a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
f"Received: a={a}"
)
# `convert_to_tensor` does not support passing complex tensors. We separate
# the input out into real and imaginary and convert them separately.
real, imag = a
real = convert_to_tensor(real)
imag = convert_to_tensor(imag)
# Check shape.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and imaginary."
"Both the real and imaginary parts should have the same shape. "
f"Received: a[0].shape = {real.shape}, a[1].shape = {imag.shape}"
)
# Ensure dtype is float.
if not torch.is_floating_point(real) or not torch.is_floating_point(imag):
raise ValueError(
"At least one tensor in input `a` is not of type float."
f"Received: a={a}."
)

complex_input = torch.complex(real, imag)
return complex_input


def fft(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = torch.fft.fft(complex_input)
return complex_output.real, complex_output.imag


def fft2(a):
complex_input = _get_complex_tensor_from_tuple(a)
complex_output = torch.fft.fft2(complex_input)
return complex_output.real, complex_output.imag
138 changes: 138 additions & 0 deletions keras_core/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,141 @@ def qr(x, mode="reduced"):
if any_symbolic_tensors((x,)):
return Qr(mode=mode).symbolic_call(x)
return backend.math.qr(x, mode=mode)


class FFT(Operation):
def compute_output_spec(self, a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
f"imaginary. Received: a={a}"
)

real, imag = a
# Both real and imaginary parts should have the same shape.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
"imaginary. Both the real and imaginary parts should have the "
f"same shape. Received: a[0].shape = {real.shape}, "
f"a[1].shape = {imag.shape}"
)

# We are calculating 1D FFT. Hence, rank >= 1.
if len(real.shape) < 1:
raise ValueError(
f"Input should have rank >= 1. "
f"Received: input.shape = {real.shape}"
)

# The axis along which we are calculating FFT should be fully-defined.
m = real.shape[-1]
if m is None:
raise ValueError(
f"Input should have its {self.axis}th axis fully-defined. "
f"Received: input.shape = {real.shape}"
)

return (
KerasTensor(shape=real.shape, dtype=real.dtype),
KerasTensor(shape=imag.shape, dtype=imag.dtype),
)

def call(self, x):
return backend.math.fft(x)


class FFT2(Operation):
def compute_output_spec(self, a):
if not isinstance(a, (tuple, list)) or len(a) != 2:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
f"imaginary. Received: a={a}"
)

real, imag = a
# Both real and imaginary parts should have the same shape.
if real.shape != imag.shape:
raise ValueError(
"Input `a` should be a tuple of two tensors - real and "
"imaginary. Both the real and imaginary parts should have the "
f"same shape. Received: a[0].shape = {real.shape}, "
f"a[1].shape = {imag.shape}"
)
# We are calculating 2D FFT. Hence, rank >= 2.
if len(real.shape) < 2:
raise ValueError(
f"Input should have rank >= 2. "
f"Received: input.shape = {real.shape}"
)

# The axes along which we are calculating FFT should be fully-defined.
m = real.shape[-1]
n = real.shape[-2]
if m is None or n is None:
raise ValueError(
f"Input should have its {self.axes} axes fully-defined. "
f"Received: input.shape = {real.shape}"
)

return (
KerasTensor(shape=real.shape, dtype=real.dtype),
KerasTensor(shape=imag.shape, dtype=imag.dtype),
)

def call(self, x):
return backend.math.fft2(x)


@keras_core_export("keras_core.ops.fft")
def fft(a):
"""Computes the Fast Fourier Transform along last axis of input.

Args:
a: Tuple of the real and imaginary parts of the input tensor. Both
tensors in the tuple should be of floating type.

Returns:
A tuple containing two tensors - the real and imaginary parts of the
output tensor.

Example:

>>> a = (
... keras_core.ops.convert_to_tensor([1., 2.]),
... keras_core.ops.convert_to_tensor([0., 1.]),
... )
>>> fft(x)
(array([ 3., -1.], dtype=float32), array([ 1., -1.], dtype=float32))
"""
if any_symbolic_tensors(a):
return FFT().symbolic_call(a)
return backend.math.fft(a)


@keras_core_export("keras_core.ops.fft2")
def fft2(a):
"""Computes the 2D Fast Fourier Transform along the last two axes of input.

Args:
a: Tuple of the real and imaginary parts of the input tensor. Both
tensors in the tuple should be of floating type.

Returns:
A tuple containing two tensors - the real and imaginary parts of the
output.

Example:

>>> x = (
... keras_core.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
... keras_core.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
... )
>>> fft2(x)
(array([[ 6., 0.],
[ 0., -2.]], dtype=float32), array([[ 2., 0.],
[ 0., -2.]], dtype=float32))
"""
if any_symbolic_tensors(a):
return FFT2().symbolic_call(a)
return backend.math.fft2(a)
58 changes: 58 additions & 0 deletions keras_core/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ def test_qr(self):
self.assertEqual(q.shape, qref_shape)
self.assertEqual(r.shape, rref_shape)

def test_fft(self):
real = KerasTensor((None, 4, 3), dtype="float32")
imag = KerasTensor((None, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft((real, imag))
ref = np.fft.fft(np.ones((2, 4, 3)))
ref_shape = (None,) + ref.shape[1:]
self.assertEqual(real_output.shape, ref_shape)
self.assertEqual(imag_output.shape, ref_shape)

def test_fft2(self):
real = KerasTensor((None, 4, 3), dtype="float32")
imag = KerasTensor((None, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft2((real, imag))
ref = np.fft.fft2(np.ones((2, 4, 3)))
ref_shape = (None,) + ref.shape[1:]
self.assertEqual(real_output.shape, ref_shape)
self.assertEqual(imag_output.shape, ref_shape)


class MathOpsStaticShapeTest(testing.TestCase):
@pytest.mark.skipif(
Expand Down Expand Up @@ -100,6 +118,22 @@ def test_qr(self):
self.assertEqual(q.shape, qref.shape)
self.assertEqual(r.shape, rref.shape)

def test_fft(self):
real = KerasTensor((2, 4, 3), dtype="float32")
imag = KerasTensor((2, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft((real, imag))
ref = np.fft.fft(np.ones((2, 4, 3)))
self.assertEqual(real_output.shape, ref.shape)
self.assertEqual(imag_output.shape, ref.shape)

def test_fft2(self):
real = KerasTensor((2, 4, 3), dtype="float32")
imag = KerasTensor((2, 4, 3), dtype="float32")
real_output, imag_output = kmath.fft2((real, imag))
ref = np.fft.fft2(np.ones((2, 4, 3)))
self.assertEqual(real_output.shape, ref.shape)
self.assertEqual(imag_output.shape, ref.shape)


class MathOpsCorrectnessTest(testing.TestCase):
@pytest.mark.skipif(
Expand Down Expand Up @@ -226,3 +260,27 @@ def test_qr(self):
qref, rref = np.linalg.qr(x, mode="complete")
self.assertAllClose(qref, q)
self.assertAllClose(rref, r)

def test_fft(self):
real = np.random.random((2, 4, 3))
imag = np.random.random((2, 4, 3))
complex_arr = real + 1j * imag

real_output, imag_output = kmath.fft((real, imag))
ref = np.fft.fft(complex_arr)
real_ref = np.real(ref)
imag_ref = np.imag(ref)
self.assertAllClose(real_ref, real_output)
self.assertAllClose(imag_ref, imag_output)

def test_fft2(self):
real = np.random.random((2, 4, 3))
imag = np.random.random((2, 4, 3))
complex_arr = real + 1j * imag

real_output, imag_output = kmath.fft2((real, imag))
ref = np.fft.fft2(complex_arr)
real_ref = np.real(ref)
imag_ref = np.imag(ref)
self.assertAllClose(real_ref, real_output)
self.assertAllClose(imag_ref, imag_output)