Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enumerable#min(count) and #max(count) #13057

Merged
merged 5 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -803,18 +803,48 @@ describe "Enumerable" do

describe "max" do
it { [1, 2, 3].max.should eq(3) }
it { [1, 2, 3].max(0).should eq([] of Int32) }
it { [1, 2, 3].max(1).should eq([3]) }
it { [1, 2, 3].max(2).should eq([3, 2]) }
it { [1, 2, 3].max(3).should eq([3, 2, 1]) }
it { [1, 2, 3].max(4).should eq([3, 2, 1]) }
it { ([] of Int32).max(0).should eq([] of Int32) }
it { ([] of Int32).max(5).should eq([] of Int32) }
it {
(0..1000).map { |x| (x*137 + x*x*139) % 5000 }.max(10).should eq([
4992, 4990, 4980, 4972, 4962, 4962, 4960, 4960, 4952, 4952,
])
}

it "does not modify the array" do
xs = [7, 5, 2, 4, 9]
xs.max(2).should eq([9, 7])
xs.should eq([7, 5, 2, 4, 9])
end

it "raises if empty" do
expect_raises Enumerable::EmptyError do
([] of Int32).max
end
end

it "raises if n is negative" do
expect_raises ArgumentError do
([1, 2, 3] of Int32).max(-1)
end
end

it "raises if not comparable" do
expect_raises ArgumentError do
[Float64::NAN, 1.0, 2.0, Float64::NAN].max
end
end

it "raises if not comparable in max(n)" do
expect_raises ArgumentError do
[Float64::NAN, 1.0, 2.0, Float64::NAN].max(2)
end
end
end

describe "max?" do
Expand Down Expand Up @@ -851,18 +881,48 @@ describe "Enumerable" do

describe "min" do
it { [1, 2, 3].min.should eq(1) }
it { [1, 2, 3].min(0).should eq([] of Int32) }
it { [1, 2, 3].min(1).should eq([1]) }
it { [1, 2, 3].min(2).should eq([1, 2]) }
it { [1, 2, 3].min(3).should eq([1, 2, 3]) }
it { [1, 2, 3].min(4).should eq([1, 2, 3]) }
it { ([] of Int32).min(0).should eq([] of Int32) }
it { ([] of Int32).min(1).should eq([] of Int32) }
it {
(0..1000).map { |x| (x*137 + x*x*139) % 5000 }.min(10).should eq([
0, 10, 20, 26, 26, 26, 26, 30, 32, 32,
])
}

it "does not modify the array" do
xs = [7, 5, 2, 4, 9]
xs.min(2).should eq([2, 4])
xs.should eq([7, 5, 2, 4, 9])
end

it "raises if empty" do
expect_raises Enumerable::EmptyError do
([] of Int32).min
end
end

it "raises if n is negative" do
expect_raises ArgumentError do
([1, 2, 3] of Int32).min(-1)
end
end

it "raises if not comparable" do
expect_raises ArgumentError do
[-1.0, Float64::NAN, -3.0].min
end
end

it "raises if not comparable in min(n)" do
expect_raises ArgumentError do
[Float64::NAN, 1.0, 2.0, Float64::NAN].min(2)
end
end
end

describe "min?" do
Expand Down
77 changes: 77 additions & 0 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,35 @@ module Enumerable(T)
ary
end

private def quickselect_internal(data : Array(T), left : Int, right : Int, k : Int) : T
loop do
return data[left] if left == right
pivot_index = left + (right - left)//2
pivot_index = quickselect_partition_internal(data, left, right, pivot_index)
if k == pivot_index
return data[k]
elsif k < pivot_index
right = pivot_index - 1
else
left = pivot_index + 1
end
end
end

private def quickselect_partition_internal(data : Array(T), left : Int, right : Int, pivot_index : Int) : Int
pivot_value = data[pivot_index]
data.swap(pivot_index, right)
store_index = left
(left...right).each do |i|
if compare_or_raise(data[i], pivot_value) < 0
data.swap(store_index, i)
store_index += 1
end
end
data.swap(right, store_index)
store_index
end

# Returns the element with the maximum value in the collection.
#
# It compares using `>` so it will work for any type that supports that method.
Expand All @@ -984,6 +1013,30 @@ module Enumerable(T)
max_by? &.itself
end

# Returns an array of the maximum *count* elements, sorted descending.
#
# It compares using `<=>` so it will work for any type that supports that method.
#
# ```
# [7, 5, 2, 4, 9].max(3) # => [9, 7, 5]
# %w[Eve Alice Bob Mallory Carol].max(2) # => ["Mallory", "Eve"]
# ```
#
# Returns all elements sorted descending if *count* is greater than the number
# of elements in the source.
#
# Raises `Enumerable::ArgumentError` if *count* is negative or if any elements
# are not comparable.
def max(count : Int) : Array(T)
raise ArgumentError.new("Count must be positive") if count < 0
data = self.is_a?(Array) ? self.dup : self.to_a
n = data.size
count = n if count > n
(0...count).map do |i|
quickselect_internal(data, 0, n - 1, n - 1 - i)
end
end

# Returns the element for which the passed block returns with the maximum value.
#
# It compares using `>` so the block must return a type that supports that method
Expand Down Expand Up @@ -1073,6 +1126,30 @@ module Enumerable(T)
min_by? &.itself
end

# Returns an array of the minimum *count* elements, sorted ascending.
#
# It compares using `<=>` so it will work for any type that supports that method.
#
# ```
# [7, 5, 2, 4, 9].min(3) # => [2, 4, 5]
# %w[Eve Alice Bob Mallory Carol].min(2) # => ["Alice", "Bob"]
# ```
#
# Returns all elements sorted ascending if *count* is greater than the number
# of elements in the source.
#
# Raises `Enumerable::ArgumentError` if *count* is negative or if any elements
# are not comparable.
def min(count : Int) : Array(T)
raise ArgumentError.new("Count must be positive") if count < 0
data = self.is_a?(Array) ? self.dup : self.to_a
n = data.size
count = n if count > n
(0...count).map do |i|
quickselect_internal(data, 0, n - 1, i)
end
end

# Returns the element for which the passed block returns with the minimum value.
#
# It compares using `<` so the block must return a type that supports that method
Expand Down