From 222859466d7f351639a71cf1cad18bec3cd4dcae Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Tue, 9 Nov 2021 12:05:54 +0000 Subject: [PATCH] Ensure MaxBy/MinBy return first element if all keys are null. --- .../System.Linq/src/System/Linq/Max.cs | 19 ++++++++++----- .../System.Linq/src/System/Linq/Min.cs | 19 ++++++++++----- src/libraries/System.Linq/tests/MaxTests.cs | 24 +++++++++---------- src/libraries/System.Linq/tests/MinTests.cs | 24 +++++++++---------- 4 files changed, 50 insertions(+), 36 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/Max.cs b/src/libraries/System.Linq/src/System/Linq/Max.cs index 6a847068c5fa6..2c05d2dc29dbb 100644 --- a/src/libraries/System.Linq/src/System/Linq/Max.cs +++ b/src/libraries/System.Linq/src/System/Linq/Max.cs @@ -583,15 +583,22 @@ public static decimal Max(this IEnumerable source) if (default(TKey) is null) { - while (key == null) + if (key == null) { - if (!e.MoveNext()) + TSource firstValue = value; + + do { - return value; - } + if (!e.MoveNext()) + { + // All keys are null, surface the first element. + return firstValue; + } - value = e.Current; - key = keySelector(value); + value = e.Current; + key = keySelector(value); + } + while (key == null); } while (e.MoveNext()) diff --git a/src/libraries/System.Linq/src/System/Linq/Min.cs b/src/libraries/System.Linq/src/System/Linq/Min.cs index a66b0cfcf88fd..9f9266ff1bb83 100644 --- a/src/libraries/System.Linq/src/System/Linq/Min.cs +++ b/src/libraries/System.Linq/src/System/Linq/Min.cs @@ -541,15 +541,22 @@ public static decimal Min(this IEnumerable source) if (default(TKey) is null) { - while (key == null) + if (key == null) { - if (!e.MoveNext()) + TSource firstValue = value; + + do { - return value; - } + if (!e.MoveNext()) + { + // All keys are null, surface the first element. + return firstValue; + } - value = e.Current; - key = keySelector(value); + value = e.Current; + key = keySelector(value); + } + while (key == null); } while (e.MoveNext()) diff --git a/src/libraries/System.Linq/tests/MaxTests.cs b/src/libraries/System.Linq/tests/MaxTests.cs index c69945636c467..78318b64b1194 100644 --- a/src/libraries/System.Linq/tests/MaxTests.cs +++ b/src/libraries/System.Linq/tests/MaxTests.cs @@ -890,27 +890,27 @@ public static void MaxBy_Generic_EmptyReferenceSource_ReturnsNull() } [Fact] - public static void MaxBy_Generic_StructSourceAllKeysAreNull_ReturnsLastElement() + public static void MaxBy_Generic_StructSourceAllKeysAreNull_ReturnsFirstElement() { - Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string))); - Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string), comparer: null)); - Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); + Assert.Equal(0, Enumerable.Range(0, 5).MaxBy(x => default(string))); + Assert.Equal(0, Enumerable.Range(0, 5).MaxBy(x => default(string), comparer: null)); + Assert.Equal(0, Enumerable.Range(0, 5).MaxBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); } [Fact] - public static void MaxBy_Generic_NullableSourceAllKeysAreNull_ReturnsLastElement() + public static void MaxBy_Generic_NullableSourceAllKeysAreNull_ReturnsFirstElement() { - Assert.Equal(4, Enumerable.Range(0, 5).Cast().MaxBy(x => default(int?))); - Assert.Equal(4, Enumerable.Range(0, 5).Cast().MaxBy(x => default(int?), comparer: null)); - Assert.Equal(4, Enumerable.Range(0, 5).Cast().MaxBy(x => default(int?), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); + Assert.Equal(0, Enumerable.Range(0, 5).Cast().MaxBy(x => default(int?))); + Assert.Equal(0, Enumerable.Range(0, 5).Cast().MaxBy(x => default(int?), comparer: null)); + Assert.Equal(0, Enumerable.Range(0, 5).Cast().MaxBy(x => default(int?), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); } [Fact] - public static void MaxBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsLastElement() + public static void MaxBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsFirstElement() { - Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string))); - Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), comparer: null)); - Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); + Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string))); + Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), comparer: null)); + Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); } [Theory] diff --git a/src/libraries/System.Linq/tests/MinTests.cs b/src/libraries/System.Linq/tests/MinTests.cs index ab6af55c36e78..31296da12c691 100644 --- a/src/libraries/System.Linq/tests/MinTests.cs +++ b/src/libraries/System.Linq/tests/MinTests.cs @@ -868,27 +868,27 @@ public static void MinBy_Generic_EmptyReferenceSource_ReturnsNull() } [Fact] - public static void MinBy_Generic_StructSourceAllKeysAreNull_ReturnsLastElement() + public static void MinBy_Generic_StructSourceAllKeysAreNull_ReturnsFirstElement() { - Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string))); - Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string), comparer: null)); - Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); + Assert.Equal(0, Enumerable.Range(0, 5).MinBy(x => default(string))); + Assert.Equal(0, Enumerable.Range(0, 5).MinBy(x => default(string), comparer: null)); + Assert.Equal(0, Enumerable.Range(0, 5).MinBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); } [Fact] - public static void MinBy_Generic_NullableSourceAllKeysAreNull_ReturnsLastElement() + public static void MinBy_Generic_NullableSourceAllKeysAreNull_ReturnsFirstElement() { - Assert.Equal(4, Enumerable.Range(0, 5).Cast().MinBy(x => default(int?))); - Assert.Equal(4, Enumerable.Range(0, 5).Cast().MinBy(x => default(int?), comparer: null)); - Assert.Equal(4, Enumerable.Range(0, 5).Cast().MinBy(x => default(int?), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); + Assert.Equal(0, Enumerable.Range(0, 5).Cast().MinBy(x => default(int?))); + Assert.Equal(0, Enumerable.Range(0, 5).Cast().MinBy(x => default(int?), comparer: null)); + Assert.Equal(0, Enumerable.Range(0, 5).Cast().MinBy(x => default(int?), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); } [Fact] - public static void MinBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsLastElement() + public static void MinBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsFirstElement() { - Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string))); - Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), comparer: null)); - Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); + Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string))); + Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), comparer: null)); + Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), Comparer.Create((_, _) => throw new InvalidOperationException("comparer should not be called.")))); } [Theory]