Skip to content

Commit

Permalink
Add fuzz testing for chexify.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 534012861
  • Loading branch information
stompchicken authored and ChexDev committed May 22, 2023
1 parent 9a7b5ef commit 50ed7f4
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
13 changes: 9 additions & 4 deletions chex/_src/asserts_chexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dataclasses
import functools
import re
import threading
from typing import Any, Callable, FrozenSet

from absl import logging
Expand Down Expand Up @@ -178,6 +179,8 @@ def logp1_abs_safe(x: chex.Array) -> chex.Array:
thread_pool = futures.ThreadPoolExecutor(1, f'async_chex_{func_name}')
# A deque for futures.
async_check_futures = collections.deque()
# Protect the futures from concurrent access.
async_check_futures_lock = threading.Lock()

# Checkification.
checkified_fn = checkify.checkify(fn, errors=errors)
Expand All @@ -191,8 +194,9 @@ def _chexified_fn(*args, **kwargs):

if async_check:
# Check completed calls.
while async_check_futures and async_check_futures[0].done():
_check_error(async_check_futures.popleft().result(async_timeout))
with async_check_futures_lock:
while async_check_futures and async_check_futures[0].done():
_check_error(async_check_futures.popleft().result(async_timeout))

# Run the checkified function.
_ai.CHEXIFY_STORAGE.level += 1
Expand All @@ -214,8 +218,9 @@ def _chexified_fn(*args, **kwargs):

def _wait_checks():
if async_check:
while async_check_futures:
_check_error(async_check_futures.popleft().result(async_timeout))
with async_check_futures_lock:
while async_check_futures:
_check_error(async_check_futures.popleft().result(async_timeout))

# Add a barrier callback to the global storage.
_ai.CHEXIFY_STORAGE.wait_fns.append(_wait_checks)
Expand Down
65 changes: 65 additions & 0 deletions chex/_src/asserts_chexify_fuzz_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fuzz test for `asserts_chexify.py`."""

import concurrent.futures
import random
import time

from absl.testing import absltest
from chex._src import asserts
from chex._src import asserts_chexify
from chex._src import variants
import jax
import jax.numpy as jnp


class AssertsChexifyFuzzTest(variants.TestCase):
"""Fuzz test for thread safety of chexify."""

def test_thread_safety(self):

def divide_by_zero():
result = jnp.ones(shape=()) / jnp.zeros(shape=())
asserts.assert_tree_all_finite(result)
return result

def chexified_divide_by_zero():
fn = asserts_chexify.chexify(divide_by_zero, async_check=True)
fn()
# Introduce random delay between the two calls, otherwise we will not
# get interleaving of the two operations between threads because they
# happen too quickly.
time.sleep(random.uniform(0.01, 0.02))
asserts_chexify.block_until_chexify_assertions_complete()

with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
futures = []
for _ in range(1000):
future = executor.submit(chexified_divide_by_zero)
futures.append(future)

for future in concurrent.futures.as_completed(futures):
try:
future.result()
except AssertionError:
pass

asserts_chexify.block_until_chexify_assertions_complete()


if __name__ == '__main__':
jax.config.update('jax_numpy_rank_promotion', 'raise')
absltest.main()

0 comments on commit 50ed7f4

Please sign in to comment.