Skip to content

Commit

Permalink
Add WorkQueue for use by SpinGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
tjwp committed Nov 5, 2024
1 parent aa0f3d1 commit b89613d
Show file tree
Hide file tree
Showing 5 changed files with 396 additions and 17 deletions.
49 changes: 33 additions & 16 deletions lib/cli/ui/spinner/spin_group.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# typed: true

require_relative '../work_queue'

module CLI
module UI
module Spinner
Expand Down Expand Up @@ -47,6 +49,8 @@ def pause_spinners(&block)
# ==== Options
#
# * +:auto_debrief+ - Automatically debrief exceptions or through success_debrief? Default to true
# * +:max_concurrent+ - Maximum number of concurrent tasks. Default is 0 (effectively unlimited)
# * +:work_queue+ - Custom WorkQueue instance. If not provided, a new one will be created
#
# ==== Example Usage
#
Expand All @@ -59,12 +63,23 @@ def pause_spinners(&block)
#
# https://user-images.githubusercontent.com/3074765/33798558-c452fa26-dce8-11e7-9e90-b4b34df21a46.gif
#
sig { params(auto_debrief: T::Boolean).void }
def initialize(auto_debrief: true)
sig do
params(
auto_debrief: T::Boolean,
max_concurrent: Integer,
work_queue: T.nilable(WorkQueue),
).void
end
def initialize(auto_debrief: true, max_concurrent: 0, work_queue: nil)
@m = Mutex.new
@tasks = []
@auto_debrief = auto_debrief
@start = Time.new
@internal_work_queue = work_queue.nil?
@work_queue = T.let(
work_queue || WorkQueue.new(max_concurrent.zero? ? 1024 : max_concurrent),
WorkQueue,
)
if block_given?
yield self
wait
Expand Down Expand Up @@ -97,14 +112,15 @@ class Task
final_glyph: T.proc.params(success: T::Boolean).returns(T.any(Glyph, String)),
merged_output: T::Boolean,
duplicate_output_to: IO,
work_queue: WorkQueue,
block: T.proc.params(task: Task).returns(T.untyped),
).void
end
def initialize(title, final_glyph:, merged_output:, duplicate_output_to:, &block)
def initialize(title, final_glyph:, merged_output:, duplicate_output_to:, work_queue:, &block)
@title = title
@final_glyph = final_glyph
@always_full_render = title =~ Formatter::SCAN_WIDGET
@thread = Thread.new do
@future = work_queue.enqueue do
cap = CLI::UI::StdoutRouter::Capture.new(
merged_output: merged_output, duplicate_output_to: duplicate_output_to,
) { block.call(self) }
Expand All @@ -120,21 +136,21 @@ def initialize(title, final_glyph:, merged_output:, duplicate_output_to:, &block
@force_full_render = false
@done = false
@exception = nil
@success = false
@success = false
end

# Checks if a task is finished
#
sig { returns(T::Boolean) }
def check
return true if @done
return false if @thread.alive?
return false unless @future.completed?

@done = true
begin
status = @thread.join.status
@success = (status == false)
@success = false if @thread.value == TASK_FAILED
result = @future.value
@success = true
@success = false if result == TASK_FAILED
rescue => exc
@exception = exc
@success = false
Expand Down Expand Up @@ -189,11 +205,6 @@ def update_title(new_title)
end
end

sig { void }
def interrupt
@thread.raise(Interrupt)
end

private

sig { params(index: Integer, terminal_width: Integer).returns(String) }
Expand Down Expand Up @@ -232,7 +243,11 @@ def glyph(index)
final_glyph
end
elsif CLI::UI.enable_cursor?
CLI::UI.enable_color? ? GLYPHS[index] : RUNES[index]
if !@future.started?
CLI::UI.enable_color? ? Glyph::CHEVRON.to_s : Glyph::CHEVRON.char
else
CLI::UI.enable_color? ? GLYPHS[index] : RUNES[index]
end
else
Glyph::HOURGLASS.char
end
Expand Down Expand Up @@ -283,6 +298,7 @@ def add(
final_glyph: final_glyph,
merged_output: merged_output,
duplicate_output_to: duplicate_output_to,
work_queue: @work_queue,
&block
)
end
Expand Down Expand Up @@ -348,13 +364,14 @@ def wait
sleep(PERIOD)
end

@work_queue.wait if @internal_work_queue
if @auto_debrief
debrief
else
all_succeeded?
end
rescue Interrupt
@tasks.each(&:interrupt)
@work_queue.interrupt
raise
end

Expand Down
140 changes: 140 additions & 0 deletions lib/cli/ui/work_queue.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# typed: strict
# frozen_string_literal: true

module CLI
module UI
class WorkQueue
extend T::Sig

class Future
extend T::Sig

sig { void }
def initialize
@mutex = T.let(Mutex.new, Mutex)
@condition = T.let(ConditionVariable.new, ConditionVariable)
@completed = T.let(false, T::Boolean)
@started = T.let(false, T::Boolean)
@result = T.let(nil, T.untyped)
@error = T.let(nil, T.nilable(Exception))
end

sig { params(result: T.untyped).void }
def complete(result)
@mutex.synchronize do
@completed = true
@result = result
@condition.broadcast
end
end

sig { params(error: Exception).void }
def fail(error)
@mutex.synchronize do
return if @completed

@completed = true
@error = error
@condition.broadcast
end
end

sig { returns(T.untyped) }
def value
@mutex.synchronize do
@condition.wait(@mutex) until @completed
raise @error if @error

@result
end
end

sig { returns(T::Boolean) }
def completed?
@mutex.synchronize { @completed }
end

sig { returns(T::Boolean) }
def started?
@mutex.synchronize { @started }
end

sig { void }
def start
@mutex.synchronize do
@started = true
@condition.broadcast
end
end
end

sig { params(max_concurrent: Integer).void }
def initialize(max_concurrent)
@max_concurrent = max_concurrent
@queue = T.let(Queue.new, Queue)
@mutex = T.let(Mutex.new, Mutex)
@condition = T.let(ConditionVariable.new, ConditionVariable)
@workers = T.let([], T::Array[Thread])
end

sig { params(block: T.proc.returns(T.untyped)).returns(Future) }
def enqueue(&block)
future = Future.new
@mutex.synchronize do
start_worker if @workers.size < @max_concurrent
end
@queue.push([future, block])
future
end

sig { void }
def wait
@queue.close
@workers.each(&:join)
end

sig { void }
def interrupt
@mutex.synchronize do
@queue.close
# Fail any remaining tasks in the queue
until @queue.empty?
future, _block = @queue.pop(true)
future&.fail(Interrupt.new)
end
# Interrupt all worker threads
@workers.each { |worker| worker.raise(Interrupt) if worker.alive? }
@workers.clear
end
end

private

sig { void }
def start_worker
@workers << Thread.new do
loop do
work = @queue.pop
break if work.nil?

future, block = work

begin
future.start
result = block.call
future.complete(result)
rescue Interrupt => e
future.fail(e)
raise # Always re-raise interrupts to terminate the worker
rescue StandardError => e
future.fail(e)
# Don't re-raise standard errors - allow worker to continue
end
end
rescue Interrupt
# Clean exit on interrupt
end
end
end
end
end
73 changes: 73 additions & 0 deletions test/cli/ui/spinner/spin_group_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,79 @@ def test_spin_group_success_debrief
assert(sg.wait)
end
end

def test_spin_group_with_custom_work_queue
capture_io do
CLI::UI::StdoutRouter.ensure_activated
work_queue = CLI::UI::WorkQueue.new(2)
sg = SpinGroup.new(work_queue: work_queue)

tasks_executed = 0
3.times do |i|
sg.add("Task #{i + 1}") do
tasks_executed += 1
sleep(0.1)
true
end
end

assert(sg.wait)
assert_equal(3, tasks_executed)
end
end

def test_spin_group_with_max_concurrent
capture_io do
CLI::UI::StdoutRouter.ensure_activated
sg = SpinGroup.new(max_concurrent: 2)

start_times = []
3.times do |i|
sg.add("Task #{i + 1}") do
start_times << Time.now
sleep(0.2)
true
end
end

assert(sg.wait)
assert_equal(3, start_times.size)
assert(start_times[2] - start_times[0] >= 0.2, 'Third task should start after the first one finishes')
end
end

def test_spin_group_interrupt
capture_io do
CLI::UI::StdoutRouter.ensure_activated
sg = SpinGroup.new
task_completed = false
task_interrupted = false

# Use Queue for thread-safe signaling
started_queue = Queue.new

sg.add('Interruptible task') do
started_queue.push(true)
10.times { sleep(0.1) }
task_completed = true
rescue Interrupt
task_interrupted = true
raise
end

t = Thread.new { sg.wait }

# Wait for task to start
started_queue.pop
sleep(0.1) # Small delay to ensure we're in sleep
t.raise(Interrupt)
sleep(0.1) # Small delay to react to Interrupt

assert_raises(Interrupt) { t.join }
refute(task_completed, 'Task should not have completed')
assert(task_interrupted, 'Task should have been interrupted')
end
end
end
end
end
Expand Down
1 change: 0 additions & 1 deletion test/cli/ui/spinner_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def test_spinner_task_error_through_returning_error
CLI::UI::Spinner::TASK_FAILED
end
end

assert_match(/✗/, out)
assert_match(/Task Failed: broken/, out)
assert_match(/STDERR[^\n]*\n[^\n]*not empty/, out)
Expand Down
Loading

0 comments on commit b89613d

Please sign in to comment.