diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index 459af8d5c898..6c2f46daa562 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -160,6 +160,19 @@ describe WaitGroup do extra.get.should eq(32) end + it "takes a block to WaitGroup.wait" do + fiber_count = 10 + completed = Array.new(fiber_count) { false } + + WaitGroup.wait do |wg| + fiber_count.times do |i| + wg.spawn { completed[i] = true } + end + end + + completed.should eq [true] * 10 + end + # the test takes far too much time for the interpreter to complete {% unless flag?(:interpreted) %} it "stress add/done/wait" do diff --git a/src/wait_group.cr b/src/wait_group.cr index 2fd49c593b56..89510714c727 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -42,12 +42,46 @@ class WaitGroup end end + # Yields a `WaitGroup` instance and waits at the end of the block for all of + # the work enqueued inside it to complete. + # + # ``` + # WaitGroup.wait do |wg| + # items.each do |item| + # wg.spawn { process item } + # end + # end + # ``` + def self.wait : Nil + instance = new + yield instance + instance.wait + end + def initialize(n : Int32 = 0) @waiting = Crystal::PointerLinkedList(Waiting).new @lock = Crystal::SpinLock.new @counter = Atomic(Int32).new(n) end + # Increment the counter by 1, perform the work inside the block in a separate + # fiber, decrementing the counter after it completes or raises. Returns the + # `Fiber` that was spawned. + # + # ``` + # wg = WaitGroup.new + # wg.spawn { do_something } + # wg.wait + # ``` + def spawn(&block) : Fiber + add + ::spawn do + block.call + ensure + done + end + end + # Increments the counter by how many fibers we want to wait for. # # A negative value decrements the counter. When the counter reaches zero,