Skip to content

Commit

Permalink
Merge pull request #3346 from BalmungSan/improve-mutex
Browse files Browse the repository at this point in the history
Optimize `Mutex` & `AtomicCell`
  • Loading branch information
djspiewak authored Feb 7, 2023
2 parents 3cc94c2 + 1d7468d commit 2f3ed2a
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2020-2022 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.benchmarks

import cats.effect.IO
import cats.effect.std._
import cats.effect.unsafe.implicits.global
import cats.syntax.all._

import org.openjdk.jmh.annotations._

import java.util.concurrent.TimeUnit

/**
 * JMH throughput benchmarks for `AtomicCell`, covering both the `concurrent` and the `async`
 * constructors.
 *
 * To compare results between versions:
 *
 * benchmarks/run-benchmark AtomicCellBenchmark
 *
 * The results are written to `benchmarks/results`.
 *
 * To run the benchmark from within sbt instead:
 *
 * Jmh / run -i 10 -wi 10 -f 2 -t 1 cats.effect.benchmarks.AtomicCellBenchmark
 *
 * That is "10 iterations", "10 warm-up iterations", "2 forks", "1 thread". As a rule of thumb,
 * benchmarks should run for at least 10 iterations; more is better.
 */
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
class AtomicCellBenchmark {

  // Number of updates (sequential or parallel, depending on the scenario) per cell.
  @Param(Array("10", "50", "100"))
  var fibers: Int = _

  // Number of times each scenario, including cell creation, is repeated per invocation.
  @Param(Array("1000"))
  var iterations: Int = _

  // Repeats `scenario` `iterations` times and runs it synchronously on the global runtime.
  private def runScenario[A](scenario: IO[A]): Unit =
    scenario.replicateA_(iterations).unsafeRunSync()

  // The single update used by every scenario: increment the stored value by one.
  private def increment(atomicCell: AtomicCell[IO, Int]): IO[Unit] =
    atomicCell.evalUpdate(i => IO(i + 1))

  // Creates a cell, then increments it `fibers` times one after another.
  private def happyPathImpl(cell: IO[AtomicCell[IO, Int]]): Unit = {
    val sequentialUpdates = cell.flatMap { atomicCell =>
      increment(atomicCell).replicateA_(fibers)
    }
    runScenario(sequentialUpdates)
  }

  @Benchmark
  def happyPathConcurrent(): Unit =
    happyPathImpl(cell = AtomicCell.concurrent(0))

  @Benchmark
  def happyPathAsync(): Unit =
    happyPathImpl(cell = AtomicCell.async(0))

  // Creates a cell, then increments it `fibers` times in parallel.
  private def highContentionImpl(cell: IO[AtomicCell[IO, Int]]): Unit = {
    val parallelUpdates = cell.flatMap { atomicCell =>
      increment(atomicCell).parReplicateA_(fibers)
    }
    runScenario(parallelUpdates)
  }

  @Benchmark
  def highContentionConcurrent(): Unit =
    highContentionImpl(cell = AtomicCell.concurrent(0))

  @Benchmark
  def highContentionAsync(): Unit =
    highContentionImpl(cell = AtomicCell.async(0))

  // From inside an outer `evalUpdate`, starts `fibers` inner updates in parallel and cancels
  // each of them after a `cede`; the outer update then stores -1.
  private def cancellationImpl(cell: IO[AtomicCell[IO, Int]]): Unit = {
    val cancelledUpdates = cell.flatMap { atomicCell =>
      atomicCell.evalUpdate { _ =>
        val startThenCancel: IO[Unit] =
          increment(atomicCell).start.flatMap(fiber => IO.cede >> fiber.cancel)

        startThenCancel.parReplicateA_(fibers).as(-1)
      }
    }
    runScenario(cancelledUpdates)
  }

  @Benchmark
  def cancellationConcurrent(): Unit =
    cancellationImpl(cell = AtomicCell.concurrent(0))

  @Benchmark
  def cancellationAsync(): Unit =
    cancellationImpl(cell = AtomicCell.async(0))
}
107 changes: 107 additions & 0 deletions benchmarks/src/main/scala/cats/effect/benchmarks/MutexBenchmark.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2020-2022 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.benchmarks

import cats.effect.IO
import cats.effect.std._
import cats.effect.unsafe.implicits.global
import cats.syntax.all._

import org.openjdk.jmh.annotations._

import java.util.concurrent.TimeUnit

/**
 * JMH throughput benchmarks for `Mutex`, covering both the `concurrent` and the `async`
 * constructors.
 *
 * To compare results between versions:
 *
 * benchmarks/run-benchmark MutexBenchmark
 *
 * The results are written to `benchmarks/results`.
 *
 * To run the benchmark from within sbt instead:
 *
 * Jmh / run -i 10 -wi 10 -f 2 -t 1 cats.effect.benchmarks.MutexBenchmark
 *
 * That is "10 iterations", "10 warm-up iterations", "2 forks", "1 thread". As a rule of thumb,
 * benchmarks should run for at least 10 iterations; more is better.
 */
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
class MutexBenchmark {

  // Number of lock acquisitions (sequential or parallel, depending on the scenario) per mutex.
  @Param(Array("10", "50", "100"))
  var fibers: Int = _

  // Number of times each scenario, including mutex creation, is repeated per invocation.
  @Param(Array("1000"))
  var iterations: Int = _

  // Repeats `scenario` `iterations` times and runs it synchronously on the global runtime.
  private def runScenario[A](scenario: IO[A]): Unit =
    scenario.replicateA_(iterations).unsafeRunSync()

  // Creates a mutex, then acquires and releases it `fibers` times one after another.
  private def happyPathImpl(mutex: IO[Mutex[IO]]): Unit = {
    val sequentialLocks = mutex.flatMap { lockable =>
      lockable.lock.use_.replicateA_(fibers)
    }
    runScenario(sequentialLocks)
  }

  @Benchmark
  def happyPathConcurrent(): Unit =
    happyPathImpl(mutex = Mutex.concurrent)

  @Benchmark
  def happyPathAsync(): Unit =
    happyPathImpl(mutex = Mutex.async)

  // Creates a mutex, then acquires and releases it `fibers` times in parallel.
  private def highContentionImpl(mutex: IO[Mutex[IO]]): Unit = {
    val parallelLocks = mutex.flatMap { lockable =>
      lockable.lock.use_.parReplicateA_(fibers)
    }
    runScenario(parallelLocks)
  }

  @Benchmark
  def highContentionConcurrent(): Unit =
    highContentionImpl(mutex = Mutex.concurrent)

  @Benchmark
  def highContentionAsync(): Unit =
    highContentionImpl(mutex = Mutex.async)

  // Inside `lock.surround`, starts `fibers` further lock attempts in parallel and cancels each
  // of them after a `cede`.
  private def cancellationImpl(mutex: IO[Mutex[IO]]): Unit = {
    val cancelledLocks = mutex.flatMap { lockable =>
      val startThenCancel: IO[Unit] =
        lockable.lock.use_.start.flatMap(fiber => IO.cede >> fiber.cancel)

      lockable.lock.surround {
        startThenCancel.parReplicateA_(fibers)
      }
    }
    runScenario(cancelledLocks)
  }

  @Benchmark
  def cancellationConcurrent(): Unit =
    cancellationImpl(mutex = Mutex.concurrent)

  @Benchmark
  def cancellationAsync(): Unit =
    cancellationImpl(mutex = Mutex.async)
}
4 changes: 4 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,10 @@ lazy val std = crossProject(JSPlatform, JVMPlatform, NativePlatform)
"cats.effect.std.Queue#CircularBufferQueue.onOfferNoCapacity"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"cats.effect.std.Queue#DroppingQueue.onOfferNoCapacity"),
// introduced by #3346
// private stuff
ProblemFilters.exclude[MissingClassProblem](
"cats.effect.std.Mutex$Impl"),
// introduced by #3347
// private stuff
ProblemFilters.exclude[MissingClassProblem]("cats.effect.std.AtomicCell$Impl")
Expand Down
62 changes: 58 additions & 4 deletions std/shared/src/main/scala/cats/effect/std/AtomicCell.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import cats.syntax.all._
* calling fiber.
*
* {{{
* final class ParkingLot(data: AtomicCell[IO, Vector[Boolean]], rnd: Random[IO]) {
* def getSpot: IO[Option[Int]] =
* data.evalModify { spots =>
* val availableSpots = spots.zipWithIndex.collect { case (true, idx) => idx }
* rnd.shuffleVector(availableSpots).map { shuffled =>
* val acquired = shuffled.headOption
* val next = acquired.fold(spots)(a => spots.updated(a, false)) // mark the chosen spot as taken
* (next, shuffled.headOption)
Expand Down Expand Up @@ -131,7 +131,12 @@ object AtomicCell {
* Initializes the `AtomicCell` using the provided value.
*/
def of[A](init: A)(implicit F: Concurrent[F]): F[AtomicCell[F, A]] =
  F match {
    // When the supplied instance is also an `Async`, build the cell via `AtomicCell.async`.
    case f: Async[F] =>
      AtomicCell.async(init)(f)
    // Any other `Concurrent` instance goes through `AtomicCell.concurrent`.
    case _ =>
      AtomicCell.concurrent(init)
  }

@deprecated("Use the version that only requires Concurrent", since = "3.5.0")
private[std] def of[A](init: A, F: Async[F]): F[AtomicCell[F, A]] =
Expand All @@ -148,9 +153,12 @@ object AtomicCell {
of(M.empty)(F)
}

/**
 * Builds an `AtomicCell` holding `init`, backed by an `AsyncImpl` that is guarded by a mutex
 * obtained from `Mutex.async`.
 */
private[effect] def async[F[_], A](init: A)(implicit F: Async[F]): F[AtomicCell[F, A]] =
  F.map(Mutex.async[F]) { guard => new AsyncImpl(init, guard) }

/**
 * Builds an `AtomicCell` holding `init`, backed by a `ConcurrentImpl` made of a `Ref` for the
 * value and a mutex obtained from `Mutex.concurrent`.
 */
private[effect] def concurrent[F[_], A](init: A)(
    implicit F: Concurrent[F]): F[AtomicCell[F, A]] =
  (Ref.of[F, A](init), Mutex.concurrent[F]).mapN { (ref, m) => new ConcurrentImpl(ref, m) }

private final class ConcurrentImpl[F[_], A](
ref: Ref[F, A],
Expand Down Expand Up @@ -184,4 +192,50 @@ object AtomicCell {
override def evalUpdateAndGet(f: A => F[A]): F[A] =
evalModify(a => f(a).map(aa => (aa, aa)))
}

/**
 * `AtomicCell` implementation that keeps the current value in a private mutable field and
 * runs every read and write of that field inside `mutex.lock.surround`.
 *
 * @param init
 *   the initial value of the cell
 * @param mutex
 *   the mutex whose lock surrounds every access to the stored value
 */
private final class AsyncImpl[F[_], A](
    init: A,
    mutex: Mutex[F]
)(
    implicit F: Async[F]
) extends AtomicCell[F, A] {
  private var cell: A = init

  // Runs `body` surrounded by the mutex's lock.
  private def locked[B](body: F[B]): F[B] =
    mutex.lock.surround(body)

  override def get: F[A] =
    locked(F.delay(cell))

  override def set(a: A): F[Unit] =
    locked(F.delay { cell = a })

  override def modify[B](f: A => (A, B)): F[B] =
    evalModify(current => F.pure(f(current)))

  // Reads the current value, runs `f` on it, then stores the new value and yields the result,
  // all within a single lock acquisition.
  override def evalModify[B](f: A => F[(A, B)]): F[B] =
    locked {
      F.flatMap(F.flatMap(F.delay(cell))(f)) {
        case (next, result) =>
          F.delay {
            cell = next
            result
          }
      }
    }

  override def evalUpdate(f: A => F[A]): F[Unit] =
    evalModify(current => F.map(f(current))(next => (next, ())))

  override def evalGetAndUpdate(f: A => F[A]): F[A] =
    evalModify(previous => F.map(f(previous))(next => (next, previous)))

  override def evalUpdateAndGet(f: A => F[A]): F[A] =
    evalModify(current => F.map(f(current))(next => (next, next)))
}
}
Loading

0 comments on commit 2f3ed2a

Please sign in to comment.