Fix finite difference in tests (#5)
alexeytochin authored Jun 20, 2024
1 parent f4e367d · commit 3b7add7
Showing 6 changed files with 34 additions and 38 deletions.
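
The core change replaces tf.vectorized_map with tf.map_fn in the finite-difference helper of tests/finite_difference.py and re-enables the gradient and Hessian tests that were skipped with the marker "fix_finite_difference". Judging from the hunk below, the helper approximates a batch Jacobian with a forward difference: for func mapping [batch_size, dim_x] to [batch_size, dim_y], a step epsilon, and the j-th unit vector e_j,

    jacobian[b, i, j] ≈ (func(x + epsilon * e_j)[b, i] - func(x)[b, i]) / epsilon

A runnable sketch of this computation follows the tests/finite_difference.py hunk below.
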
3 changes: 2 additions & 1 deletion tests/benchmark.py
@@ -26,10 +26,11 @@
from tqdm import tqdm
from tabulate import tabulate

from tests.common import generate_ctc_loss_inputs, tf_ctc_loss
from tf_seq2seq_losses.classic_ctc_loss import classic_ctc_loss
from tf_seq2seq_losses.simplified_ctc_loss import simplified_ctc_loss

from tests.common import generate_ctc_loss_inputs, tf_ctc_loss


class TestBenchmarkCtcLosses(unittest.TestCase):
"""Benchmark for CTC losses implementations."""
7 changes: 3 additions & 4 deletions tests/finite_difference.py
@@ -89,7 +89,7 @@ def func_(x_: tf.Tensor) -> tf.Tensor:
def _finite_difference_batch_jacobian(
func, x, epsilon: Union[float, tf.Tensor]
) -> tf.Tensor:
"""Calculate final difference Jacobian approximation
"""Calculate finite difference Jacobian approximation
Args:
func: shape = [batch_size, dim_x] -> [batch_size, dim_y]
@@ -104,9 +104,8 @@ def _finite_difference_batch_jacobian(
# shape = [dim_x, batch_size, dim_x]
y0 = func(x)
# shape = [batch_size, dim_y]
dy_transposed = (
tf.vectorized_map(fn=func, elems=pre_x1) - tf.expand_dims(y0, 0)
) / epsilon

dy_transposed = (tf.map_fn(fn=func, elems=pre_x1) - tf.expand_dims(y0, 0)) / epsilon
# shape = [dim_x, batch_size, dim_y]
dy = tf.transpose(dy_transposed, perm=[1, 2, 0])

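As a reading aid, here is a minimal, self-contained sketch of a forward-difference batch Jacobian written in the same shape conventions as the hunk above. The construction of the perturbed inputs pre_x1 lies outside the shown hunk, so that part, like the function name itself, is an assumption rather than the repository's code.

import tensorflow as tf


def finite_difference_batch_jacobian_sketch(func, x, epsilon):
    """Sketch: forward-difference Jacobian of func per batch element.

    func: [batch_size, dim_x] -> [batch_size, dim_y]
    Returns a tensor of shape [batch_size, dim_y, dim_x].
    """
    dim_x = tf.shape(x)[-1]
    # Assumed construction: one copy of x with coordinate j shifted by epsilon,
    # for every j.
    # shape = [dim_x, batch_size, dim_x]
    pre_x1 = tf.expand_dims(x, 0) + tf.expand_dims(
        epsilon * tf.eye(dim_x, dtype=x.dtype), 1
    )
    # shape = [batch_size, dim_y]
    y0 = func(x)
    # tf.map_fn evaluates func once per perturbed copy; this is the call
    # the commit uses in place of tf.vectorized_map.
    # shape = [dim_x, batch_size, dim_y]
    dy_transposed = (tf.map_fn(fn=func, elems=pre_x1) - tf.expand_dims(y0, 0)) / epsilon
    # shape = [batch_size, dim_y, dim_x]
    return tf.transpose(dy_transposed, perm=[1, 2, 0])

Unlike tf.vectorized_map, which auto-vectorizes the mapped function and supports only a subset of ops, tf.map_fn runs it element by element inside a loop and places no such restriction on func; that is a plausible reason for the switch, though the commit message does not state one.
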
23 changes: 10 additions & 13 deletions tests/test_classic_ctc_loss.py
@@ -15,18 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import unittest

import numpy as np
import tensorflow as tf

from tests.common import generate_ctc_loss_inputs
from tests.test_ctc_losses import TestCtcLoss
from tests.finite_difference import finite_difference_batch_jacobian
from tf_seq2seq_losses.base_loss import ctc_loss_from_logproba
from tf_seq2seq_losses.classic_ctc_loss import ClassicCtcLossData, classic_ctc_loss
from tf_seq2seq_losses.tools import logit_to_logproba

from tests.common import generate_ctc_loss_inputs
from tests.test_ctc_losses import TestCtcLoss
from tests.finite_difference import finite_difference_batch_jacobian


class TestClassicCtcLoss(TestCtcLoss):
"""Tests for the classic CTC loss."""
@@ -393,7 +392,6 @@ def test_compare_gradient_with_tf_implementation(self):
tf_version_gradient, classic_version_gradient, 4
)

@unittest.skip("fix_finite_difference")
def test_gradient_vs_finite_difference(self):
"""Test for the comparison of the gradient with the finite difference."""
blank_index = 0
@@ -478,29 +476,28 @@ def test_second_derivative_shape(self):
list(hessian_analytic.shape),
)

@unittest.skip("fix_finite_difference")
def test_hessian_vs_finite_difference(self):
"""Test for the comparison of the Hessian with the finite difference."""
input_dict = generate_ctc_loss_inputs(
max_logit_length=4, batch_size=2, random_seed=0, num_tokens=2, blank_index=0
)
logits = input_dict["logits"]

def gradient_fn(logits):
with tf.GradientTape() as tape:
tape.watch([logits])
def gradient_fn(logits_):
with tf.GradientTape() as tape_:
tape_.watch([logits_])
loss = tf.reduce_sum(
classic_ctc_loss(
labels=input_dict["labels"],
logits=logits,
logits=logits_,
label_length=input_dict["label_length"],
logit_length=input_dict["logit_length"],
blank_index=0,
)
)
gradient = tape.gradient(loss, sources=logits)
gradient_ = tape_.gradient(loss, sources=logits_)
# shape = [batch_size, logit_length, num_tokens]
return gradient
return gradient_

hessian_numerical = finite_difference_batch_jacobian(
func=gradient_fn, x=logits, epsilon=1e-4
7 changes: 4 additions & 3 deletions tests/test_hessian.py
@@ -18,16 +18,17 @@

import tensorflow as tf

from tests.common import generate_ctc_loss_inputs
from tests.test_ctc_losses import TestCtcLoss
from tests.finite_difference import finite_difference_batch_jacobian
from tf_seq2seq_losses import classic_ctc_loss
from tf_seq2seq_losses.base_loss import ctc_loss_from_logproba
from tf_seq2seq_losses.simplified_ctc_loss import (
SimplifiedCtcLossData,
simplified_ctc_loss,
)

from tf_seq2seq_losses.tools import logit_to_logproba
from tests.common import generate_ctc_loss_inputs
from tests.test_ctc_losses import TestCtcLoss
from tests.finite_difference import finite_difference_batch_jacobian


class TestSimplifiedCtcLoss(TestCtcLoss):
29 changes: 13 additions & 16 deletions tests/test_simplified_ctc_loss.py
@@ -15,20 +15,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import unittest

import numpy as np
import tensorflow as tf

from tests.common import generate_ctc_loss_inputs
from tests.test_ctc_losses import TestCtcLoss
from tests.finite_difference import finite_difference_batch_jacobian
from tf_seq2seq_losses.simplified_ctc_loss import (
simplified_ctc_loss,
SimplifiedCtcLossData,
)
from tf_seq2seq_losses.tools import logit_to_logproba

from tests.common import generate_ctc_loss_inputs
from tests.test_ctc_losses import TestCtcLoss
from tests.finite_difference import finite_difference_batch_jacobian


class TestSimplifiedCtcLoss(TestCtcLoss):
"""Tests for the simplified CTC loss."""
@@ -258,7 +257,6 @@ def test_length_two(self):
places=6,
)

@unittest.skip("fix_finite_difference")
def test_gradient_with_finite_difference(self):
"""Test for the gradient with finite difference."""
blank_index = 0
@@ -290,7 +288,6 @@ def loss_fn(logits_):
tape.watch([logits])
loss = tf.reduce_sum(loss_fn(logits))
gradient_analytic = tape.gradient(loss, sources=logits)

self.assert_tensors_almost_equal(gradient_numerical, gradient_analytic, 1)

def test_autograph(self):
@@ -336,9 +333,9 @@ def func():
output = simplified_ctc_loss(
labels, logits, label_length, logit_length, 0
)
loss = tf.reduce_mean(output)
gradient = tape.gradient(loss, sources=logits)
return loss, gradient
loss_ = tf.reduce_mean(output)
gradient_ = tape.gradient(loss_, sources=logits)
return loss_, gradient_

loss, gradient = func()

@@ -356,14 +353,14 @@ def test_zero_batch_size(self):
def func():
with tf.GradientTape() as tape:
tape.watch([logits])
loss_samplewise = simplified_ctc_loss(
loss_samplewise_ = simplified_ctc_loss(
labels, logits, label_length, logit_length, 0
)
loss = tf.reduce_sum(loss_samplewise)
gradient = tape.gradient(loss, sources=logits)
return loss_samplewise, gradient
loss = tf.reduce_sum(loss_samplewise_)
gradient_ = tape.gradient(loss, sources=logits)
return loss_samplewise_, gradient_

loss_samplewise, gradient = func()
loss_samplewise_, gradient = func()

self.assertEqual([0], loss_samplewise.shape)
self.assertEqual([0], loss_samplewise_.shape)
self.assertEqual([0, 4, 3], gradient.shape)
3 changes: 2 additions & 1 deletion tests/test_tools.py
@@ -20,7 +20,6 @@
import tensorflow as tf
import numpy as np

from tests.finite_difference import finite_difference_batch_jacobian
from tf_seq2seq_losses.tools import (
logsumexp,
insert_zeros,
@@ -29,6 +28,8 @@
expand_many_dims,
)

from tests.finite_difference import finite_difference_batch_jacobian


class TestLogSumExp(unittest.TestCase):
"""Tests for the logsumexp function."""
