diff --git a/MoreLinq.Test/LagTest.cs b/MoreLinq.Test/LagTest.cs
index dd864b9f9..331fae6f7 100644
--- a/MoreLinq.Test/LagTest.cs
+++ b/MoreLinq.Test/LagTest.cs
@@ -15,6 +15,8 @@
// limitations under the License.
#endregion
+#nullable enable
+
namespace MoreLinq.Test
{
using NUnit.Framework;
@@ -131,5 +133,30 @@ public void TestLagPassesCorrectLagValuesOffsetBy2()
Assert.IsTrue(result.Skip(2).All(x => x.B == (x.A - 2)));
Assert.IsTrue(result.Take(2).All(x => (x.A - x.B) == x.A));
}
+
+ [Test]
+ public void TestLagWithNullableReferences()
+ {
+ var words = new[] { "foo", "bar", "baz", "qux" };
+ var result = words.Lag(2, (a, b) => new { A = a, B = b });
+ result.AssertSequenceEqual(
+ new { A = "foo", B = (string?)null },
+ new { A = "bar", B = (string?)null },
+ new { A = "baz", B = (string?)"foo" },
+ new { A = "qux", B = (string?)"bar" });
+ }
+
+ [Test]
+ public void TestLagWithNonNullableReferences()
+ {
+ var words = new[] { "foo", "bar", "baz", "qux" };
+ var empty = string.Empty;
+ var result = words.Lag(2, empty, (a, b) => new { A = a, B = b });
+ result.AssertSequenceEqual(
+ new { A = "foo", B = empty },
+ new { A = "bar", B = empty },
+ new { A = "baz", B = "foo" },
+ new { A = "qux", B = "bar" });
+ }
}
}
diff --git a/MoreLinq.Test/LeadTest.cs b/MoreLinq.Test/LeadTest.cs
index 4953a3afd..e0d4c592d 100644
--- a/MoreLinq.Test/LeadTest.cs
+++ b/MoreLinq.Test/LeadTest.cs
@@ -15,6 +15,8 @@
// limitations under the License.
#endregion
+#nullable enable
+
namespace MoreLinq.Test
{
using NUnit.Framework;
@@ -133,5 +135,30 @@ public void TestLeadPassesCorrectValueOffsetBy2()
Assert.IsTrue(result.Take(count - 2).All(x => x.B == (x.A + 2)));
Assert.IsTrue(result.Skip(count - 2).All(x => x.B == leadDefault && (x.A == count || x.A == count - 1)));
}
+
+ [Test]
+ public void TestLagWithNullableReferences()
+ {
+ var words = new[] { "foo", "bar", "baz", "qux" };
+ var result = words.Lead(2, (a, b) => new { A = a, B = b });
+ result.AssertSequenceEqual(
+ new { A = "foo", B = (string?)"baz" },
+ new { A = "bar", B = (string?)"qux" },
+ new { A = "baz", B = (string?)null },
+ new { A = "qux", B = (string?)null });
+ }
+
+ [Test]
+ public void TestLagWithNonNullableReferences()
+ {
+ var words = new[] { "foo", "bar", "baz", "qux" };
+ var empty = string.Empty;
+ var result = words.Lead(2, empty, (a, b) => new { A = a, B = b });
+ result.AssertSequenceEqual(
+ new { A = "foo", B = "baz" },
+ new { A = "bar", B = "qux" },
+ new { A = "baz", B = empty },
+ new { A = "qux", B = empty });
+ }
}
}
diff --git a/MoreLinq/Extensions.g.cs b/MoreLinq/Extensions.g.cs
index 4f94d77d3..0b4fa5e5c 100644
--- a/MoreLinq/Extensions.g.cs
+++ b/MoreLinq/Extensions.g.cs
@@ -3025,7 +3025,7 @@ public static partial class LagExtension
/// A projection function which accepts the current and lagged items (in that order) and returns a result
/// A sequence produced by projecting each element of the sequence with its lagged pairing
- public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector)
+ public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector)
=> MoreEnumerable.Lag(source, offset, resultSelector);
///
@@ -3114,7 +3114,7 @@ public static partial class LeadExtension
/// A projection function which accepts the current and subsequent (lead) element (in that order) and produces a result
/// A sequence produced by projecting each element of the sequence with its lead pairing
- public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector)
+ public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector)
=> MoreEnumerable.Lead(source, offset, resultSelector);
///
diff --git a/MoreLinq/Lag.cs b/MoreLinq/Lag.cs
index 533bbeac7..66bd4a1e7 100644
--- a/MoreLinq/Lag.cs
+++ b/MoreLinq/Lag.cs
@@ -19,6 +19,7 @@ namespace MoreLinq
{
using System;
using System.Collections.Generic;
+ using System.Linq;
public static partial class MoreEnumerable
{
@@ -36,9 +37,13 @@ public static partial class MoreEnumerable
/// A projection function which accepts the current and lagged items (in that order) and returns a result
/// A sequence produced by projecting each element of the sequence with its lagged pairing
- public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector)
+ public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector)
{
- return Lag(source, offset, default!, resultSelector);
+ if (source == null) throw new ArgumentNullException(nameof(source));
+ if (resultSelector is null) throw new ArgumentNullException(nameof(resultSelector));
+
+ return source.Select(Some)
+ .Lag(offset, default, (curr, lag) => resultSelector(curr.Value, lag is (true, var some) ? some : default));
}
///
diff --git a/MoreLinq/Lead.cs b/MoreLinq/Lead.cs
index ba627088c..25c7637c7 100644
--- a/MoreLinq/Lead.cs
+++ b/MoreLinq/Lead.cs
@@ -19,6 +19,7 @@ namespace MoreLinq
{
using System;
using System.Collections.Generic;
+ using System.Linq;
public static partial class MoreEnumerable
{
@@ -37,9 +38,13 @@ public static partial class MoreEnumerable
/// A projection function which accepts the current and subsequent (lead) element (in that order) and produces a result
/// A sequence produced by projecting each element of the sequence with its lead pairing
- public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector)
+ public static IEnumerable Lead(this IEnumerable source, int offset, Func resultSelector)
{
- return Lead(source, offset, default!, resultSelector);
+ if (source is null) throw new ArgumentNullException(nameof(source));
+ if (resultSelector is null) throw new ArgumentNullException(nameof(resultSelector));
+
+ return source.Select(Some)
+ .Lead(offset, default, (curr, lead) => resultSelector(curr.Value, lead is (true, var some) ? some : default));
}
///
diff --git a/MoreLinq/MoreEnumerable.cs b/MoreLinq/MoreEnumerable.cs
index 7d0630dbb..ca7938cc7 100644
--- a/MoreLinq/MoreEnumerable.cs
+++ b/MoreLinq/MoreEnumerable.cs
@@ -53,5 +53,9 @@ static int CountUpTo(this IEnumerable source, int max)
return count;
}
+
+ // See https://github.com/atifaziz/Optuple
+
+ static (bool HasValue, T Value) Some(T value) => (true, value);
}
}