diff --git a/MoreLinq.Test/LagTest.cs b/MoreLinq.Test/LagTest.cs index dd864b9f9..331fae6f7 100644 --- a/MoreLinq.Test/LagTest.cs +++ b/MoreLinq.Test/LagTest.cs @@ -15,6 +15,8 @@ // limitations under the License. #endregion +#nullable enable + namespace MoreLinq.Test { using NUnit.Framework; @@ -131,5 +133,30 @@ public void TestLagPassesCorrectLagValuesOffsetBy2() Assert.IsTrue(result.Skip(2).All(x => x.B == (x.A - 2))); Assert.IsTrue(result.Take(2).All(x => (x.A - x.B) == x.A)); } + + [Test] + public void TestLagWithNullableReferences() + { + var words = new[] { "foo", "bar", "baz", "qux" }; + var result = words.Lag(2, (a, b) => new { A = a, B = b }); + result.AssertSequenceEqual( + new { A = "foo", B = (string?)null }, + new { A = "bar", B = (string?)null }, + new { A = "baz", B = (string?)"foo" }, + new { A = "qux", B = (string?)"bar" }); + } + + [Test] + public void TestLagWithNonNullableReferences() + { + var words = new[] { "foo", "bar", "baz", "qux" }; + var empty = string.Empty; + var result = words.Lag(2, empty, (a, b) => new { A = a, B = b }); + result.AssertSequenceEqual( + new { A = "foo", B = empty }, + new { A = "bar", B = empty }, + new { A = "baz", B = "foo" }, + new { A = "qux", B = "bar" }); + } } } diff --git a/MoreLinq.Test/LeadTest.cs b/MoreLinq.Test/LeadTest.cs index 4953a3afd..e0d4c592d 100644 --- a/MoreLinq.Test/LeadTest.cs +++ b/MoreLinq.Test/LeadTest.cs @@ -15,6 +15,8 @@ // limitations under the License. #endregion +#nullable enable + namespace MoreLinq.Test { using NUnit.Framework; @@ -133,5 +135,30 @@ public void TestLeadPassesCorrectValueOffsetBy2() Assert.IsTrue(result.Take(count - 2).All(x => x.B == (x.A + 2))); Assert.IsTrue(result.Skip(count - 2).All(x => x.B == leadDefault && (x.A == count || x.A == count - 1))); } + + [Test] + public void TestLagWithNullableReferences() + { + var words = new[] { "foo", "bar", "baz", "qux" }; + var result = words.Lead(2, (a, b) => new { A = a, B = b }); + result.AssertSequenceEqual( + new { A = "foo", B = (string?)"baz" }, + new { A = "bar", B = (string?)"qux" }, + new { A = "baz", B = (string?)null }, + new { A = "qux", B = (string?)null }); + } + + [Test] + public void TestLagWithNonNullableReferences() + { + var words = new[] { "foo", "bar", "baz", "qux" }; + var empty = string.Empty; + var result = words.Lead(2, empty, (a, b) => new { A = a, B = b }); + result.AssertSequenceEqual( + new { A = "foo", B = "baz" }, + new { A = "bar", B = "qux" }, + new { A = "baz", B = empty }, + new { A = "qux", B = empty }); + } } } diff --git a/MoreLinq/Extensions.g.cs b/MoreLinq/Extensions.g.cs index 4f94d77d3..0b4fa5e5c 100644 --- a/MoreLinq/Extensions.g.cs +++ b/MoreLinq/Extensions.g.cs @@ -3025,7 +3025,7 @@ public static partial class LagExtension /// A projection function which accepts the current and lagged items (in that order) and returns a result /// A sequence produced by projecting each element of the sequence with its lagged pairing - public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector) + public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector) => MoreEnumerable.Lag(source, offset, resultSelector); /// @@ -3114,7 +3114,7 @@ public static partial class LeadExtension /// A projection function which accepts the current and subsequent (lead) element (in that order) and produces a result /// A sequence produced by projecting each element of the sequence with its lead pairing - public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector) + public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector) => MoreEnumerable.Lead(source, offset, resultSelector); /// diff --git a/MoreLinq/Lag.cs b/MoreLinq/Lag.cs index 533bbeac7..66bd4a1e7 100644 --- a/MoreLinq/Lag.cs +++ b/MoreLinq/Lag.cs @@ -19,6 +19,7 @@ namespace MoreLinq { using System; using System.Collections.Generic; + using System.Linq; public static partial class MoreEnumerable { @@ -36,9 +37,13 @@ public static partial class MoreEnumerable /// A projection function which accepts the current and lagged items (in that order) and returns a result /// A sequence produced by projecting each element of the sequence with its lagged pairing - public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector) + public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector) { - return Lag(source, offset, default!, resultSelector); + if (source == null) throw new ArgumentNullException(nameof(source)); + if (resultSelector is null) throw new ArgumentNullException(nameof(resultSelector)); + + return source.Select(Some) + .Lag(offset, default, (curr, lag) => resultSelector(curr.Value, lag is (true, var some) ? some : default)); } /// diff --git a/MoreLinq/Lead.cs b/MoreLinq/Lead.cs index ba627088c..25c7637c7 100644 --- a/MoreLinq/Lead.cs +++ b/MoreLinq/Lead.cs @@ -19,6 +19,7 @@ namespace MoreLinq { using System; using System.Collections.Generic; + using System.Linq; public static partial class MoreEnumerable { @@ -37,9 +38,13 @@ public static partial class MoreEnumerable /// A projection function which accepts the current and subsequent (lead) element (in that order) and produces a result /// A sequence produced by projecting each element of the sequence with its lead pairing - public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector) + public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector) { - return Lead(source, offset, default!, resultSelector); + if (source is null) throw new ArgumentNullException(nameof(source)); + if (resultSelector is null) throw new ArgumentNullException(nameof(resultSelector)); + + return source.Select(Some) + .Lead(offset, default, (curr, lead) => resultSelector(curr.Value, lead is (true, var some) ? some : default)); } /// diff --git a/MoreLinq/MoreEnumerable.cs b/MoreLinq/MoreEnumerable.cs index 7d0630dbb..ca7938cc7 100644 --- a/MoreLinq/MoreEnumerable.cs +++ b/MoreLinq/MoreEnumerable.cs @@ -53,5 +53,9 @@ static int CountUpTo(this IEnumerable source, int max) return count; } + + // See https://github.com/atifaziz/Optuple + + static (bool HasValue, T Value) Some(T value) => (true, value); } }