Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Atomic Barrier and PhaseLockingTestMixin #2772

Merged
merged 1 commit into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/delta/BusyWait.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta

import scala.concurrent.duration._

object BusyWait {
  /**
   * Keep checking if `check` returns `true` until it's the case or `waitTime` expires.
   *
   * `check` is always evaluated at least once, even if `waitTime` is zero or negative.
   * Between evaluations the calling thread sleeps for up to 10 ms, but never longer than
   * the time that is left until the deadline.
   *
   * Note: This function is used as a helper function for the Concurrency Testing framework,
   * and should not be used in production code. Production code should not use polling
   * and should instead use signalling to coordinate.
   *
   * @param check    condition to poll; it is re-evaluated on every iteration.
   * @param waitTime maximum amount of time to keep polling.
   * @return `true` when `check` returned `true`, and `false` if `waitTime` expired first.
   */
  def until(
      check: => Boolean,
      waitTime: FiniteDuration): Boolean = {
    val pollInterval: Duration = 10.millis
    val deadline = waitTime.fromNow

    var satisfied = false
    var expired = false
    while (!satisfied && !expired) {
      satisfied = check
      if (!satisfied) {
        // `timeLeft` is negative once the deadline has passed (for example because `check`
        // itself was slow, or `waitTime` was negative). `Thread.sleep` rejects negative
        // values with an IllegalArgumentException, so clamp the sleep time at zero.
        val sleepTimeMs = math.max(0L, pollInterval.min(deadline.timeLeft).toMillis)
        Thread.sleep(sleepTimeMs)
        expired = !deadline.hasTimeLeft()
      }
    }
    satisfied
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.fuzzer

import java.util.concurrent.atomic.AtomicInteger

/**
* An atomic barrier is similar to a countdown latch,
* except that the content is a state transition system with semantic meaning
* instead of a simple counter.
*
* It is designed with a single writer ("unblocker") thread and a single reader ("waiter") thread
* in mind. It is concurrency safe with more writers and readers, but using more is likely to cause
* race conditions for legal transitions. That is to say, trying to perform an otherwise
* legal transition twice is illegal and may occur if there is more than one unblocker or
* waiter thread.
* Having additional passive state observers that only call [[load()]] is never an issue.
*
* Legal transitions are:
* - BLOCKED -> UNBLOCKED
* - BLOCKED -> REQUESTED
* - REQUESTED -> UNBLOCKED
* - UNBLOCKED -> PASSED
*
* The state itself is kept in an `AtomicInteger`; this object's own monitor is used only so
* that [[waitToPass()]] can sleep via `wait()` and [[unblock()]] can wake it via `notifyAll()`.
*/
class AtomicBarrier {

import AtomicBarrier._

// Holds the ordinal of the current State; a new barrier starts out as Blocked.
private final val state: AtomicInteger = new AtomicInteger(State.Blocked.ordinal)

/**
* Get the current state.
*
* Reads the stored ordinal and translates it back to a [[State]] through `stateIndex`.
*/
def load(): State = {
val ordinal = state.get()
// We should never be putting illegal state ordinals into `state`,
// so this should always succeed.
stateIndex(ordinal)
}

/**
* Transition to the Unblocked state.
*
* Succeeds from Blocked or Requested and wakes up every thread waiting on this barrier's
* monitor. Throws [[IllegalStateTransitionException]] if the barrier is observed in any other
* state (Unblocked or Passed).
*/
def unblock(): Unit = {
// Just hot-retry this, since it never needs to wait to make progress.
var successful = false
while(!successful) {
val currentValue = state.get()
if (currentValue == State.Blocked.ordinal || currentValue == State.Requested.ordinal) {
// The compare-and-set and the notification happen under the same monitor that
// `waitToPass` holds while it decides to call `wait()`, so the state change cannot
// slip in between the waiter's check and its `wait()`.
this.synchronized {
// Fails if the state changed since `currentValue` was read (e.g. Blocked -> Requested);
// in that case the loop simply re-reads the state and tries again.
successful = state.compareAndSet(currentValue, State.Unblocked.ordinal)
if (successful) {
this.notifyAll()
}
}
} else {
// if it's in any other state we will never make progress
throw new IllegalStateTransitionException(stateIndex(currentValue), State.Unblocked)
}
}
}

/**
* Wait until this barrier can be passed and then mark it as Passed.
*
* Blocks the calling thread on this barrier's monitor while the state is Blocked/Requested.
* Throws [[IllegalStateTransitionException]] if the barrier is already Passed.
* An `InterruptedException` raised by `wait()` is not caught here and propagates to the caller.
*/
def waitToPass(): Unit = {
// Every wake-up (including a spurious one) falls through to the next iteration, which
// re-reads the state before deciding what to do.
while (true) {
val currentState = load()
currentState match {
case State.Unblocked =>
// Claim the barrier; if the compare-and-set loses, loop and re-evaluate.
val updated = state.compareAndSet(currentState.ordinal, State.Passed.ordinal)
if (updated) {
return
}
case State.Passed =>
throw new IllegalStateTransitionException(State.Passed, State.Passed)
case State.Requested =>
this.synchronized {
// Re-check under the monitor: only wait if the state is still Requested at this point.
if (load().ordinal == State.Requested.ordinal) {
this.wait()
}
}
case State.Blocked =>
this.synchronized {
// Announce that a waiter has arrived, then sleep until notified.
val updated = state.compareAndSet(currentState.ordinal, State.Requested.ordinal)
if (updated) {
this.wait()
}
} // else (if we didn't succeed) just hot-retry until we do
// (or more likely pass, since unblocking is the only legal concurrent
// update with a single concurrent "waiter")
}
}
}

// Renders the barrier together with its current state, e.g. for test failure messages.
override def toString: String = s"AtomicBarrier(state=${load()})"
}

object AtomicBarrier {

  /** A state an [[AtomicBarrier]] can be in; `ordinal` is its integer encoding. */
  sealed trait State {
    def ordinal: Int
  }

  /** The four concrete barrier states together with their fixed ordinals. */
  object State {
    case object Blocked extends State { override final val ordinal = 0 }
    case object Unblocked extends State { override final val ordinal = 1 }
    case object Requested extends State { override final val ordinal = 2 }
    case object Passed extends State { override final val ordinal = 3 }
  }

  /** Reverse lookup from an ordinal to the [[State]] that carries it. */
  final val stateIndex: Map[Int, State] = {
    val allStates = List(State.Blocked, State.Unblocked, State.Requested, State.Passed)
    (for (s <- allStates) yield (s.ordinal, s)).toMap
  }
}

/**
 * Signals an attempt to move an [[AtomicBarrier]] along a transition that is not legal.
 *
 * @param fromState state the barrier was observed in.
 * @param toState   state the caller tried to move it to.
 */
class IllegalStateTransitionException(fromState: AtomicBarrier.State, toState: AtomicBarrier.State)
  extends RuntimeException(
    "State transition from " + fromState + " to " + toState + " is illegal.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.concurrency

import scala.concurrent.duration._

import org.apache.spark.sql.delta.BusyWait
import org.apache.spark.sql.delta.fuzzer.AtomicBarrier

import org.apache.spark.SparkFunSuite

trait PhaseLockingTestMixin { self: SparkFunSuite =>

  /**
   * Poll `barrier` until it is in `state`, failing the test if that does not happen
   * before `waitTime` expires.
   */
  def busyWaitForState(
      barrier: AtomicBarrier,
      state: AtomicBarrier.State,
      waitTime: FiniteDuration): Unit = {
    busyWaitFor(
      check = barrier.load() == state,
      timeout = waitTime,
      message = s"Exceeded deadline waiting for $barrier to transition to state $state")
  }

  /**
   * Poll `check` until it returns `true`, failing the test with `message` if that does not
   * happen before `timeout` expires.
   *
   * `message` is optional and is only evaluated when the wait actually fails.
   */
  def busyWaitFor(
      check: => Boolean,
      timeout: FiniteDuration,
      // by-name, so state captured by the message is rendered at failure time, not at call time
      message: => String = "Exceeded deadline waiting for check to become true."): Unit = {
    val conditionMet = BusyWait.until(check, timeout)
    if (!conditionMet) {
      fail(message)
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.fuzzer

import scala.concurrent.duration._

import org.apache.spark.sql.delta.concurrency.PhaseLockingTestMixin

import org.apache.spark.SparkFunSuite

/** Exercises [[AtomicBarrier]] with one waiter thread, in both orderings of wait and unblock. */
class AtomicBarrierSuite extends SparkFunSuite
  with PhaseLockingTestMixin {

  /** Upper bound for every busy-wait and thread join in this suite. */
  val timeout: FiniteDuration = 5000.millis

  /**
   * Create (but do not start) a thread that calls `barrier.waitToPass()`.
   *
   * The thread is a daemon: if an assertion fails before the barrier is unblocked, the
   * waiter stays parked in `waitToPass()` forever, and a non-daemon thread in that state
   * could keep the test JVM from shutting down.
   */
  private def newWaiterThread(barrier: AtomicBarrier): Thread = {
    val waiter = new Thread(() => {
      barrier.waitToPass()
    })
    waiter.setDaemon(true)
    waiter
  }

  test("Atomic Barrier - wait before unblock") {
    val barrier = new AtomicBarrier
    assert(AtomicBarrier.State.Blocked === barrier.load())
    val thread = newWaiterThread(barrier)
    // Creating the thread alone must not touch the barrier.
    assert(AtomicBarrier.State.Blocked === barrier.load())
    thread.start()
    busyWaitForState(barrier, AtomicBarrier.State.Requested, timeout)
    assert(thread.isAlive) // should be stuck waiting for unblock
    barrier.unblock()
    busyWaitForState(barrier, AtomicBarrier.State.Passed, timeout)
    thread.join(timeout.toMillis) // shouldn't take long
    assert(!thread.isAlive) // should have passed the barrier and completed
  }

  test("Atomic Barrier - unblock before wait") {
    val barrier = new AtomicBarrier
    assert(AtomicBarrier.State.Blocked === barrier.load())
    val thread = newWaiterThread(barrier)
    // Creating the thread alone must not touch the barrier.
    assert(AtomicBarrier.State.Blocked === barrier.load())
    barrier.unblock()
    assert(AtomicBarrier.State.Unblocked === barrier.load())
    thread.start()
    busyWaitForState(barrier, AtomicBarrier.State.Passed, timeout)
    thread.join(timeout.toMillis) // shouldn't take long
    assert(!thread.isAlive) // should have passed the barrier and completed
  }
}
Loading