Skip to content

Commit

Permalink
Implement ForwardingTimeout.cancel() (#1395)
Browse files Browse the repository at this point in the history
* Implement ForwardingTimeout.cancel()

This needs lock-checking functions to be forwarded also,
as the timeout now holds more state than before.

* Functions are now open
  • Loading branch information
squarejesse authored Dec 15, 2023
1 parent bf29a91 commit d2d0b57
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 42 deletions.
7 changes: 5 additions & 2 deletions okio/api/okio.api
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,8 @@ public abstract class okio/ForwardingSource : okio/Source {

public class okio/ForwardingTimeout : okio/Timeout {
public fun <init> (Lokio/Timeout;)V
public fun awaitSignal (Ljava/util/concurrent/locks/Condition;)V
public fun cancel ()V
public fun clearDeadline ()Lokio/Timeout;
public fun clearTimeout ()Lokio/Timeout;
public fun deadlineNanoTime ()J
Expand All @@ -558,6 +560,7 @@ public class okio/ForwardingTimeout : okio/Timeout {
public fun throwIfReached ()V
public fun timeout (JLjava/util/concurrent/TimeUnit;)Lokio/Timeout;
public fun timeoutNanos ()J
public fun waitUntilNotified (Ljava/lang/Object;)V
}

public final class okio/GzipSink : okio/Sink {
Expand Down Expand Up @@ -771,7 +774,7 @@ public class okio/Timeout {
public static final field Companion Lokio/Timeout$Companion;
public static final field NONE Lokio/Timeout;
public fun <init> ()V
public final fun awaitSignal (Ljava/util/concurrent/locks/Condition;)V
public fun awaitSignal (Ljava/util/concurrent/locks/Condition;)V
public fun cancel ()V
public fun clearDeadline ()Lokio/Timeout;
public fun clearTimeout ()Lokio/Timeout;
Expand All @@ -783,7 +786,7 @@ public class okio/Timeout {
public fun throwIfReached ()V
public fun timeout (JLjava/util/concurrent/TimeUnit;)Lokio/Timeout;
public fun timeoutNanos ()J
public final fun waitUntilNotified (Ljava/lang/Object;)V
public fun waitUntilNotified (Ljava/lang/Object;)V
}

public final class okio/Timeout$Companion {
Expand Down
7 changes: 7 additions & 0 deletions okio/src/jvmMain/kotlin/okio/ForwardingTimeout.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package okio

import java.io.IOException
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.Condition

/** A [Timeout] which forwards calls to another. Useful for subclassing. */
open class ForwardingTimeout(
Expand Down Expand Up @@ -49,4 +50,10 @@ open class ForwardingTimeout(

@Throws(IOException::class)
override fun throwIfReached() = delegate.throwIfReached()

override fun cancel() = delegate.cancel()

override fun awaitSignal(condition: Condition) = delegate.awaitSignal(condition)

override fun waitUntilNotified(monitor: Any) = delegate.waitUntilNotified(monitor)
}
4 changes: 2 additions & 2 deletions okio/src/jvmMain/kotlin/okio/Timeout.kt
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ actual open class Timeout {
* ```
*/
@Throws(InterruptedIOException::class)
fun awaitSignal(condition: Condition) {
open fun awaitSignal(condition: Condition) {
try {
val hasDeadline = hasDeadline()
val timeoutNanos = timeoutNanos()
Expand Down Expand Up @@ -248,7 +248,7 @@ actual open class Timeout {
* ```
*/
@Throws(InterruptedIOException::class)
fun waitUntilNotified(monitor: Any) {
open fun waitUntilNotified(monitor: Any) {
try {
val hasDeadline = hasDeadline()
val timeoutNanos = timeoutNanos()
Expand Down
46 changes: 21 additions & 25 deletions okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@ import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Test

class AwaitSignalTest {
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters

@RunWith(Parameterized::class)
class AwaitSignalTest(
factory: TimeoutFactory,
) {
private val timeout = factory.newTimeout()
val executorService = TestingExecutors.newScheduledExecutorService(0)

val lock: ReentrantLock = ReentrantLock()
Expand All @@ -39,7 +46,6 @@ class AwaitSignalTest {

@Test
fun signaled() = lock.withLock {
val timeout = Timeout()
timeout.timeout(5000, TimeUnit.MILLISECONDS)
val start = now()
executorService.schedule(
Expand All @@ -54,7 +60,6 @@ class AwaitSignalTest {
@Test
fun timeout() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
val start = now()
try {
Expand All @@ -69,7 +74,6 @@ class AwaitSignalTest {
@Test
fun deadline() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
timeout.deadline(1000, TimeUnit.MILLISECONDS)
val start = now()
try {
Expand All @@ -84,7 +88,6 @@ class AwaitSignalTest {
@Test
fun deadlineBeforeTimeout() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(5000, TimeUnit.MILLISECONDS)
timeout.deadline(1000, TimeUnit.MILLISECONDS)
val start = now()
Expand All @@ -100,7 +103,6 @@ class AwaitSignalTest {
@Test
fun timeoutBeforeDeadline() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.deadline(5000, TimeUnit.MILLISECONDS)
val start = now()
Expand All @@ -116,7 +118,6 @@ class AwaitSignalTest {
@Test
fun deadlineAlreadyReached() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
timeout.deadlineNanoTime(System.nanoTime())
val start = now()
try {
Expand All @@ -131,7 +132,6 @@ class AwaitSignalTest {
@Test
fun threadInterrupted() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
val start = now()
Thread.currentThread().interrupt()
try {
Expand All @@ -147,7 +147,6 @@ class AwaitSignalTest {
@Test
fun threadInterruptedOnThrowIfReached() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
Thread.currentThread().interrupt()
try {
timeout.throwIfReached()
Expand All @@ -159,15 +158,12 @@ class AwaitSignalTest {
}

@Test
fun cancelBeforeWaitDoesNothing() {
val timeout = Timeout()
fun cancelBeforeWaitDoesNothing() = lock.withLock {
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancel()
val start = now()
try {
lock.withLock {
timeout.awaitSignal(condition)
}
timeout.awaitSignal(condition)
fail()
} catch (expected: InterruptedIOException) {
assertEquals("timeout", expected.message)
Expand All @@ -176,32 +172,26 @@ class AwaitSignalTest {
}

@Test
fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() {
fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() = lock.withLock {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancelLater(500)

val start = now()
lock.withLock {
timeout.awaitSignal(condition) // Returns early but doesn't throw.
}
timeout.awaitSignal(condition) // Returns early but doesn't throw.
assertElapsed(1000.0, start)
}

@Test
@Synchronized
fun multipleCancelsAreIdempotent() {
val timeout = Timeout()
fun multipleCancelsAreIdempotent() = lock.withLock {
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancelLater(250)
timeout.cancelLater(500)
timeout.cancelLater(750)

val start = now()
lock.withLock {
timeout.awaitSignal(condition) // Returns early but doesn't throw.
}
timeout.awaitSignal(condition) // Returns early but doesn't throw.
assertElapsed(1000.0, start)
}

Expand All @@ -225,4 +215,10 @@ class AwaitSignalTest {
TimeUnit.MILLISECONDS,
)
}

companion object {
@Parameters(name = "{0}")
@JvmStatic
fun parameters(): List<Array<out Any?>> = TimeoutFactory.entries.map { arrayOf(it) }
}
}
33 changes: 33 additions & 0 deletions okio/src/jvmTest/kotlin/okio/TimeoutFactory.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (C) 2023 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio

enum class TimeoutFactory {
BASE {
override fun newTimeout() = Timeout()
},

FORWARDING {
override fun newTimeout() = ForwardingTimeout(BASE.newTimeout())
},

ASYNC {
override fun newTimeout() = AsyncTimeout()
},
;

abstract fun newTimeout(): Timeout
}
28 changes: 15 additions & 13 deletions okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,15 @@ import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Test

class WaitUntilNotifiedTest {
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters

@RunWith(Parameterized::class)
class WaitUntilNotifiedTest(
factory: TimeoutFactory,
) {
private val timeout = factory.newTimeout()
private val executorService = newScheduledExecutorService(0)

@After
Expand All @@ -36,7 +43,6 @@ class WaitUntilNotifiedTest {
@Test
@Synchronized
fun notified() {
val timeout = Timeout()
timeout.timeout(5000, TimeUnit.MILLISECONDS)
val start = now()
executorService.schedule(
Expand All @@ -56,7 +62,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun timeout() {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
val start = now()
try {
Expand All @@ -72,7 +77,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun deadline() {
assumeNotWindows()
val timeout = Timeout()
timeout.deadline(1000, TimeUnit.MILLISECONDS)
val start = now()
try {
Expand All @@ -88,7 +92,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun deadlineBeforeTimeout() {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(5000, TimeUnit.MILLISECONDS)
timeout.deadline(1000, TimeUnit.MILLISECONDS)
val start = now()
Expand All @@ -105,7 +108,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun timeoutBeforeDeadline() {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.deadline(5000, TimeUnit.MILLISECONDS)
val start = now()
Expand All @@ -122,7 +124,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun deadlineAlreadyReached() {
assumeNotWindows()
val timeout = Timeout()
timeout.deadlineNanoTime(System.nanoTime())
val start = now()
try {
Expand All @@ -138,7 +139,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun threadInterrupted() {
assumeNotWindows()
val timeout = Timeout()
val start = now()
Thread.currentThread().interrupt()
try {
Expand All @@ -155,7 +155,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun threadInterruptedOnThrowIfReached() {
assumeNotWindows()
val timeout = Timeout()
Thread.currentThread().interrupt()
try {
timeout.throwIfReached()
Expand All @@ -170,7 +169,6 @@ class WaitUntilNotifiedTest {
@Synchronized
fun cancelBeforeWaitDoesNothing() {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancel()
val start = now()
Expand All @@ -186,7 +184,6 @@ class WaitUntilNotifiedTest {
@Test
@Synchronized
fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() {
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancelLater(500)

Expand All @@ -198,7 +195,6 @@ class WaitUntilNotifiedTest {
@Test
@Synchronized
fun multipleCancelsAreIdempotent() {
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancelLater(250)
timeout.cancelLater(500)
Expand Down Expand Up @@ -231,4 +227,10 @@ class WaitUntilNotifiedTest {
TimeUnit.MILLISECONDS,
)
}

companion object {
@Parameters(name = "{0}")
@JvmStatic
fun parameters(): List<Array<out Any?>> = TimeoutFactory.entries.map { arrayOf(it) }
}
}

0 comments on commit d2d0b57

Please sign in to comment.