Skip to content

Commit

Permalink
Fix "Lag" & "Lead" signatures with more honest nullable annotations
Browse files Browse the repository at this point in the history
This is a squashed merge of PR #805 that adds to #803.
  • Loading branch information
atifaziz authored Mar 22, 2021
1 parent ed0cf71 commit 54ccb69
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 6 deletions.
27 changes: 27 additions & 0 deletions MoreLinq.Test/LagTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// limitations under the License.
#endregion

#nullable enable

namespace MoreLinq.Test
{
using NUnit.Framework;
Expand Down Expand Up @@ -131,5 +133,30 @@ public void TestLagPassesCorrectLagValuesOffsetBy2()
Assert.IsTrue(result.Skip(2).All(x => x.B == (x.A - 2)));
Assert.IsTrue(result.Take(2).All(x => (x.A - x.B) == x.A));
}

[Test]
public void TestLagWithNullableReferences()
{
var words = new[] { "foo", "bar", "baz", "qux" };
var result = words.Lag(2, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = (string?)null },
new { A = "bar", B = (string?)null },
new { A = "baz", B = (string?)"foo" },
new { A = "qux", B = (string?)"bar" });
}

[Test]
public void TestLagWithNonNullableReferences()
{
var words = new[] { "foo", "bar", "baz", "qux" };
var empty = string.Empty;
var result = words.Lag(2, empty, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = empty },
new { A = "bar", B = empty },
new { A = "baz", B = "foo" },
new { A = "qux", B = "bar" });
}
}
}
27 changes: 27 additions & 0 deletions MoreLinq.Test/LeadTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// limitations under the License.
#endregion

#nullable enable

namespace MoreLinq.Test
{
using NUnit.Framework;
Expand Down Expand Up @@ -133,5 +135,30 @@ public void TestLeadPassesCorrectValueOffsetBy2()
Assert.IsTrue(result.Take(count - 2).All(x => x.B == (x.A + 2)));
Assert.IsTrue(result.Skip(count - 2).All(x => x.B == leadDefault && (x.A == count || x.A == count - 1)));
}

[Test]
public void TestLagWithNullableReferences()
{
var words = new[] { "foo", "bar", "baz", "qux" };
var result = words.Lead(2, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = (string?)"baz" },
new { A = "bar", B = (string?)"qux" },
new { A = "baz", B = (string?)null },
new { A = "qux", B = (string?)null });
}

[Test]
public void TestLagWithNonNullableReferences()
{
var words = new[] { "foo", "bar", "baz", "qux" };
var empty = string.Empty;
var result = words.Lead(2, empty, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = "baz" },
new { A = "bar", B = "qux" },
new { A = "baz", B = empty },
new { A = "qux", B = empty });
}
}
}
4 changes: 2 additions & 2 deletions MoreLinq/Extensions.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3025,7 +3025,7 @@ public static partial class LagExtension
/// <param name="resultSelector">A projection function which accepts the current and lagged items (in that order) and returns a result</param>
/// <returns>A sequence produced by projecting each element of the sequence with its lagged pairing</returns>

public static IEnumerable<TResult> Lag<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource, TResult> resultSelector)
public static IEnumerable<TResult> Lag<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource?, TResult> resultSelector)
=> MoreEnumerable.Lag(source, offset, resultSelector);

/// <summary>
Expand Down Expand Up @@ -3114,7 +3114,7 @@ public static partial class LeadExtension
/// <param name="resultSelector">A projection function which accepts the current and subsequent (lead) element (in that order) and produces a result</param>
/// <returns>A sequence produced by projecting each element of the sequence with its lead pairing</returns>

public static IEnumerable<TResult> Lead<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource, TResult> resultSelector)
public static IEnumerable<TResult> Lead<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource?, TResult> resultSelector)
=> MoreEnumerable.Lead(source, offset, resultSelector);

/// <summary>
Expand Down
9 changes: 7 additions & 2 deletions MoreLinq/Lag.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace MoreLinq
{
using System;
using System.Collections.Generic;
using System.Linq;

public static partial class MoreEnumerable
{
Expand All @@ -36,9 +37,13 @@ public static partial class MoreEnumerable
/// <param name="resultSelector">A projection function which accepts the current and lagged items (in that order) and returns a result</param>
/// <returns>A sequence produced by projecting each element of the sequence with its lagged pairing</returns>

public static IEnumerable<TResult> Lag<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource, TResult> resultSelector)
public static IEnumerable<TResult> Lag<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource?, TResult> resultSelector)
{
return Lag(source, offset, default!, resultSelector);
if (source == null) throw new ArgumentNullException(nameof(source));
if (resultSelector is null) throw new ArgumentNullException(nameof(resultSelector));

return source.Select(Some)
.Lag(offset, default, (curr, lag) => resultSelector(curr.Value, lag is (true, var some) ? some : default));
}

/// <summary>
Expand Down
9 changes: 7 additions & 2 deletions MoreLinq/Lead.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace MoreLinq
{
using System;
using System.Collections.Generic;
using System.Linq;

public static partial class MoreEnumerable
{
Expand All @@ -37,9 +38,13 @@ public static partial class MoreEnumerable
/// <param name="resultSelector">A projection function which accepts the current and subsequent (lead) element (in that order) and produces a result</param>
/// <returns>A sequence produced by projecting each element of the sequence with its lead pairing</returns>

public static IEnumerable<TResult> Lead<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource, TResult> resultSelector)
public static IEnumerable<TResult> Lead<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource?, TResult> resultSelector)
{
return Lead(source, offset, default!, resultSelector);
if (source is null) throw new ArgumentNullException(nameof(source));
if (resultSelector is null) throw new ArgumentNullException(nameof(resultSelector));

return source.Select(Some)
.Lead(offset, default, (curr, lead) => resultSelector(curr.Value, lead is (true, var some) ? some : default));
}

/// <summary>
Expand Down
4 changes: 4 additions & 0 deletions MoreLinq/MoreEnumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,9 @@ static int CountUpTo<T>(this IEnumerable<T> source, int max)

return count;
}

// See https://github.com/atifaziz/Optuple

static (bool HasValue, T Value) Some<T>(T value) => (true, value);
}
}

0 comments on commit 54ccb69

Please sign in to comment.