Skip to content

Commit

Permalink
Merge pull request #1126 from stakx/setup-task-result
Browse files Browse the repository at this point in the history
Add ability to set up the `.Result` of (value) tasks
  • Loading branch information
stakx authored Jan 1, 2021
2 parents bab305e + 6f6a89d commit f48c0f4
Show file tree
Hide file tree
Showing 15 changed files with 604 additions and 4 deletions.
32 changes: 32 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,38 @@ The format is loosely based on [Keep a Changelog](http://keepachangelog.com/en/1

## Unreleased

#### Added

* Ability to directly set up the `.Result` of tasks and value tasks, which makes setup expressions more uniform by rendering dedicated async verbs like `.ReturnsAsync`, `.ThrowsAsync`, etc. unnecessary:

```diff
-mock.Setup(x => x.GetFooAsync()).ReturnsAsync(foo)
+mock.Setup(x => x.GetFooAsync().Result).Returns(foo)
```

This is useful in places where there currently aren't any such async verbs at all:

```diff
-Mock.Of<X>(x => x.GetFooAsync() == Task.FromResult(foo))
+Mock.Of<X>(x => x.GetFooAsync().Result == foo)
```

This also allows recursive setups / method chaining across async calls inside a single setup expression:

```diff
-mock.Setup(x => x.GetFooAsync()).ReturnsAsync(Mock.Of<IFoo>(f => f.Bar == bar))
+mock.Setup(x => x.GetFooAsync().Result.Bar).Returns(bar)
```

or, with only `Mock.Of`:

```diff
-Mock.Of<X>(x => x.GetFooAsync() == Task.FromResult(Mock.Of<IFoo>(f => f.Bar == bar)))
+Mock.Of<X>(x => x.GetFooAsync().Result.Bar == bar)
```

This should work in all principal setup methods (`Mock.Of`, `mock.Setup…`, `mock.Verify…`). Support in `mock.Protected()` and for custom awaitable types may be added in the future. (@stakx, #1125)

#### Changed

* Attempts to mark conditionals setup as verifiable are once again allowed; it turns out that forbidding it (as was done in #997 for version 4.14.0) is in fact a regression. (@stakx, #1121)
Expand Down
23 changes: 21 additions & 2 deletions src/Moq/ActionObserver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Linq.Expressions;
using System.Reflection;

using Moq.Async;
using Moq.Expressions.Visitors;
using Moq.Internals;
using Moq.Properties;
Expand Down Expand Up @@ -60,6 +61,19 @@ public override Expression<Action<T>> ReconstructExpression<T>(Action<T> action,
var invocation = recorder.Invocation;
if (invocation != null)
{
var resultType = invocation.Method.DeclaringType;
if (resultType.IsAssignableFrom(body.Type) == false)
{
if (AwaitableFactory.TryGet(body.Type) is { } awaitableHandler
&& awaitableHandler.ResultType.IsAssignableFrom(resultType))
{
// We are here because the current invocation cannot be chained onto the previous one,
// however it *can* be chained if we assume that there was a `.Result` query on the
// former invocation that we don't see because non-virtual members aren't recorded.
// In this case, we make things work by adding back the missing `.Result`:
body = awaitableHandler.CreateResultExpression(body);
}
}
body = Expression.Call(body, invocation.Method, GetArgumentExpressions(invocation, recorder.Matches.ToArray()));
}
else
Expand Down Expand Up @@ -227,7 +241,7 @@ private sealed class Recorder : IInterceptor
private int creationTimestamp;
private Invocation invocation;
private int invocationTimestamp;
private IProxy returnValue;
private object returnValue;

public Recorder(MatcherObserver matcherObserver)
{
Expand All @@ -248,7 +262,7 @@ public IEnumerable<Match> Matches
}
}

public Recorder Next => this.returnValue?.Interceptor as Recorder;
public Recorder Next => (Awaitable.TryGetResultRecursive(this.returnValue) as IProxy)?.Interceptor as Recorder;

public void Intercept(Invocation invocation)
{
Expand Down Expand Up @@ -277,6 +291,11 @@ public void Intercept(Invocation invocation)
{
this.returnValue = null;
}
else if (AwaitableFactory.TryGet(returnType) is { } awaitableFactory)
{
var result = CreateProxy(awaitableFactory.ResultType, null, this.matcherObserver, out _);
this.returnValue = awaitableFactory.CreateCompleted(result);
}
else if (returnType.IsMockable())
{
this.returnValue = CreateProxy(returnType, null, this.matcherObserver, out _);
Expand Down
40 changes: 40 additions & 0 deletions src/Moq/Async/AwaitExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Diagnostics;
using System.Linq.Expressions;

namespace Moq.Async
{
internal sealed class AwaitExpression : Expression
{
private readonly IAwaitableFactory awaitableFactory;
private readonly Expression operand;

public AwaitExpression(Expression operand, IAwaitableFactory awaitableFactory)
{
Debug.Assert(awaitableFactory != null);
Debug.Assert(operand != null);

this.awaitableFactory = awaitableFactory;
this.operand = operand;
}

public override bool CanReduce => false;

public override ExpressionType NodeType => ExpressionType.Extension;

public Expression Operand => this.operand;

public override Type Type => this.awaitableFactory.ResultType;

public override string ToString()
{
return this.awaitableFactory.ResultType == typeof(void) ? $"await {this.operand}"
: $"(await {this.operand})";
}

protected override Expression VisitChildren(ExpressionVisitor visitor) => this;
}
}
27 changes: 27 additions & 0 deletions src/Moq/Async/AwaitableFactory`1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;

namespace Moq.Async
{
Expand All @@ -23,6 +26,30 @@ object IAwaitableFactory.CreateCompleted(object result)
return this.CreateCompleted();
}

public abstract TAwaitable CreateFaulted(Exception exception);

object IAwaitableFactory.CreateFaulted(Exception exception)
{
Debug.Assert(exception != null);

return this.CreateFaulted(exception);
}

public abstract TAwaitable CreateFaulted(IEnumerable<Exception> exceptions);

object IAwaitableFactory.CreateFaulted(IEnumerable<Exception> exceptions)
{
Debug.Assert(exceptions != null);
Debug.Assert(exceptions.Any());

return this.CreateFaulted(exceptions);
}

Expression IAwaitableFactory.CreateResultExpression(Expression awaitableExpression)
{
return new AwaitExpression(awaitableExpression, this);
}

bool IAwaitableFactory.TryGetResult(object awaitable, out object result)
{
Debug.Assert(awaitable is TAwaitable);
Expand Down
24 changes: 24 additions & 0 deletions src/Moq/Async/AwaitableFactory`2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;

namespace Moq.Async
{
Expand All @@ -23,8 +26,29 @@ object IAwaitableFactory.CreateCompleted(object result)
return this.CreateCompleted((TResult)result);
}

public abstract TAwaitable CreateFaulted(Exception exception);

object IAwaitableFactory.CreateFaulted(Exception exception)
{
Debug.Assert(exception != null);

return this.CreateFaulted(exception);
}

public abstract TAwaitable CreateFaulted(IEnumerable<Exception> exceptions);

object IAwaitableFactory.CreateFaulted(IEnumerable<Exception> exceptions)
{
Debug.Assert(exceptions != null);
Debug.Assert(exceptions.Any());

return this.CreateFaulted(exceptions);
}

public abstract bool TryGetResult(TAwaitable awaitable, out TResult result);

public abstract Expression CreateResultExpression(Expression awaitableExpression);

bool IAwaitableFactory.TryGetResult(object awaitable, out object result)
{
Debug.Assert(awaitable is TAwaitable);
Expand Down
8 changes: 8 additions & 0 deletions src/Moq/Async/IAwaitableFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;

namespace Moq.Async
{
Expand All @@ -11,6 +13,12 @@ internal interface IAwaitableFactory

object CreateCompleted(object result = null);

object CreateFaulted(Exception exception);

object CreateFaulted(IEnumerable<Exception> exceptions);

Expression CreateResultExpression(Expression awaitableExpression);

bool TryGetResult(object awaitable, out object result);
}
}
18 changes: 18 additions & 0 deletions src/Moq/Async/TaskFactory.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -17,5 +21,19 @@ public override Task CreateCompleted()
{
return Task.FromResult<object>(default);
}

public override Task CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exception);
return tcs.Task;
}

public override Task CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exceptions);
return tcs.Task;
}
}
}
24 changes: 24 additions & 0 deletions src/Moq/Async/TaskFactory`1.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -12,6 +15,27 @@ public override Task<TResult> CreateCompleted(TResult result)
return Task.FromResult(result);
}

public override Task<TResult> CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exception);
return tcs.Task;
}

public override Task<TResult> CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exceptions);
return tcs.Task;
}

public override Expression CreateResultExpression(Expression awaitableExpression)
{
return Expression.MakeMemberAccess(
awaitableExpression,
typeof(Task<TResult>).GetProperty(nameof(Task<TResult>.Result)));
}

public override bool TryGetResult(Task<TResult> task, out TResult result)
{
if (task.Status == TaskStatus.RanToCompletion)
Expand Down
16 changes: 16 additions & 0 deletions src/Moq/Async/ValueTaskFactory.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -17,5 +19,19 @@ public override ValueTask CreateCompleted()
{
return default;
}

public override ValueTask CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exception);
return new ValueTask(tcs.Task);
}

public override ValueTask CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exceptions);
return new ValueTask(tcs.Task);
}
}
}
24 changes: 24 additions & 0 deletions src/Moq/Async/ValueTaskFactory`1.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -12,6 +15,27 @@ public override ValueTask<TResult> CreateCompleted(TResult result)
return new ValueTask<TResult>(result);
}

public override ValueTask<TResult> CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exception);
return new ValueTask<TResult>(tcs.Task);
}

public override ValueTask<TResult> CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exceptions);
return new ValueTask<TResult>(tcs.Task);
}

public override Expression CreateResultExpression(Expression awaitableExpression)
{
return Expression.MakeMemberAccess(
awaitableExpression,
typeof(ValueTask<TResult>).GetProperty(nameof(ValueTask<TResult>.Result)));
}

public override bool TryGetResult(ValueTask<TResult> valueTask, out TResult result)
{
if (valueTask.IsCompletedSuccessfully)
Expand Down
Loading

0 comments on commit f48c0f4

Please sign in to comment.