diff --git a/okio/src/jvmMain/kotlin/okio/Timeout.kt b/okio/src/jvmMain/kotlin/okio/Timeout.kt index 5fb88ce47e..576591ad54 100644 --- a/okio/src/jvmMain/kotlin/okio/Timeout.kt +++ b/okio/src/jvmMain/kotlin/okio/Timeout.kt @@ -193,10 +193,19 @@ actual open class Timeout { if (waitNanos <= 0) throw InterruptedIOException("timeout") + val cancelMarkBefore = cancelMark + // Attempt to wait that long. This will return early if the monitor is notified. val nanosRemaining = condition.awaitNanos(waitNanos) - if (nanosRemaining <= 0) throw InterruptedIOException("timeout") + // If there's time remaining, we probably got the call we were waiting for. + if (nanosRemaining > 0) return + + // Return without throwing if this timeout was canceled while we were waiting. Note that this + // return is a 'spurious wakeup' because Condition.signal() was not called. + if (cancelMark !== cancelMarkBefore) return + + throw InterruptedIOException("timeout") } catch (e: InterruptedException) { Thread.currentThread().interrupt() // Retain interrupted status. throw InterruptedIOException("interrupted") diff --git a/okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt b/okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt index d8f150cb87..941791a785 100644 --- a/okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt +++ b/okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt @@ -21,7 +21,9 @@ import java.util.concurrent.locks.Condition import java.util.concurrent.locks.ReentrantLock import okio.TestUtil.assumeNotWindows import org.junit.After -import org.junit.Assert +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Assert.fail import org.junit.Test class AwaitSignalTest { @@ -57,9 +59,9 @@ class AwaitSignalTest { val start = now() try { timeout.awaitSignal(condition) - Assert.fail() + fail() } catch (expected: InterruptedIOException) { - Assert.assertEquals("timeout", expected.message) + assertEquals("timeout", expected.message) } assertElapsed(1000.0, start) } @@ -72,9 +74,9 @@ class AwaitSignalTest { val start = now() try { timeout.awaitSignal(condition) - Assert.fail() + fail() } catch (expected: InterruptedIOException) { - Assert.assertEquals("timeout", expected.message) + assertEquals("timeout", expected.message) } assertElapsed(1000.0, start) } @@ -88,9 +90,9 @@ class AwaitSignalTest { val start = now() try { timeout.awaitSignal(condition) - Assert.fail() + fail() } catch (expected: InterruptedIOException) { - Assert.assertEquals("timeout", expected.message) + assertEquals("timeout", expected.message) } assertElapsed(1000.0, start) } @@ -104,9 +106,9 @@ class AwaitSignalTest { val start = now() try { timeout.awaitSignal(condition) - Assert.fail() + fail() } catch (expected: InterruptedIOException) { - Assert.assertEquals("timeout", expected.message) + assertEquals("timeout", expected.message) } assertElapsed(1000.0, start) } @@ -119,9 +121,9 @@ class AwaitSignalTest { val start = now() try { timeout.awaitSignal(condition) - Assert.fail() + fail() } catch (expected: InterruptedIOException) { - Assert.assertEquals("timeout", expected.message) + assertEquals("timeout", expected.message) } assertElapsed(0.0, start) } @@ -134,10 +136,10 @@ class AwaitSignalTest { Thread.currentThread().interrupt() try { timeout.awaitSignal(condition) - Assert.fail() + fail() } catch (expected: InterruptedIOException) { - Assert.assertEquals("interrupted", expected.message) - Assert.assertTrue(Thread.interrupted()) + assertEquals("interrupted", expected.message) + assertTrue(Thread.interrupted()) } assertElapsed(0.0, start) } @@ -149,13 +151,60 @@ class AwaitSignalTest { Thread.currentThread().interrupt() try { timeout.throwIfReached() - Assert.fail() + fail() } catch (expected: InterruptedIOException) { - Assert.assertEquals("interrupted", expected.message) - Assert.assertTrue(Thread.interrupted()) + assertEquals("interrupted", expected.message) + assertTrue(Thread.interrupted()) } } + @Test + fun cancelBeforeWaitDoesNothing() { + val timeout = Timeout() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancel() + val start = now() + try { + lock.withLock { + timeout.awaitSignal(condition) + } + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() { + assumeNotWindows() + val timeout = Timeout() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(500) + + val start = now() + lock.withLock { + timeout.awaitSignal(condition) // Returns early but doesn't throw. + } + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun multipleCancelsAreIdempotent() { + val timeout = Timeout() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(250) + timeout.cancelLater(500) + timeout.cancelLater(750) + + val start = now() + lock.withLock { + timeout.awaitSignal(condition) // Returns early but doesn't throw. + } + assertElapsed(1000.0, start) + } + /** Returns the nanotime in milliseconds as a double for measuring timeouts. */ private fun now(): Double { return System.nanoTime() / 1000000.0 @@ -166,6 +215,14 @@ class AwaitSignalTest { * -50..+450 milliseconds. */ private fun assertElapsed(duration: Double, start: Double) { - Assert.assertEquals(duration, now() - start - 200.0, 250.0) + assertEquals(duration, now() - start - 200.0, 250.0) + } + + private fun Timeout.cancelLater(delay: Long) { + executorService.schedule( + { cancel() }, + delay, + TimeUnit.MILLISECONDS, + ) } }