Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Await operator for use in async setup expressions #1008

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions src/Moq/ActionObserver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

using System.Threading.Tasks;
using Moq.Expressions.Visitors;
using Moq.Internals;
using Moq.Linq.Expressions;
using Moq.Properties;

using TypeNameFormatter;
Expand All @@ -28,10 +29,11 @@ internal sealed class ActionObserver : ExpressionReconstructor
{
public override Expression<Action<T>> ReconstructExpression<T>(Action<T> action, object[] ctorArgs = null)
{
using (var awaitOperatorObserver = AwaitOperatorObserver.Activate())
using (var matcherObserver = MatcherObserver.Activate())
{
// Create the root recording proxy:
var root = (T)CreateProxy(typeof(T), ctorArgs, matcherObserver, out var rootRecorder);
var root = (T)CreateProxy(typeof(T), ctorArgs, awaitOperatorObserver, matcherObserver, out var rootRecorder);

Exception error = null;
try
Expand Down Expand Up @@ -61,6 +63,12 @@ public override Expression<Action<T>> ReconstructExpression<T>(Action<T> action,
if (invocation != null)
{
body = Expression.Call(body, invocation.Method, GetArgumentExpressions(invocation, recorder.Matches.ToArray()));
if (recorder.Await)
{
var genArg = invocation.Method.ReturnType.GetGenericArguments()[0];
var awaitMethod = typeof(AwaitOperator).GetMethods("Await").First().MakeGenericMethod(genArg);
body = Expression.Call(awaitMethod, body);
}
}
else
{
Expand Down Expand Up @@ -213,32 +221,37 @@ bool CanDistribute(int msi, int asi)
}

// Creates a proxy (way more light-weight than a `Mock<T>`!) with an invocation `Recorder` attached to it.
private static IProxy CreateProxy(Type type, object[] ctorArgs, MatcherObserver matcherObserver, out Recorder recorder)
private static IProxy CreateProxy(Type type, object[] ctorArgs, AwaitOperatorObserver awaitOperatorObserver, MatcherObserver matcherObserver, out Recorder recorder)
{
recorder = new Recorder(matcherObserver);
recorder = new Recorder(awaitOperatorObserver, matcherObserver);
return (IProxy)ProxyFactory.Instance.CreateProxy(type, recorder, Type.EmptyTypes, ctorArgs ?? new object[0]);
}

// Records an invocation, mocks return values, and builds a chain to the return value's recorder.
// This record represents the basis for reconstructing an expression tree.
private sealed class Recorder : IInterceptor
{
private readonly AwaitOperatorObserver awaitOperatorObserver;
private readonly MatcherObserver matcherObserver;
private int creationTimestamp;
private Invocation invocation;
private int invocationTimestamp;
private IProxy returnValue;
private object returnValue;

public Recorder(MatcherObserver matcherObserver)
public Recorder(AwaitOperatorObserver awaitOperatorObserver, MatcherObserver matcherObserver)
{
Debug.Assert(awaitOperatorObserver != null);
Debug.Assert(matcherObserver != null);

this.awaitOperatorObserver = awaitOperatorObserver;
this.matcherObserver = matcherObserver;
this.creationTimestamp = this.matcherObserver.GetNextTimestamp();
}

public Invocation Invocation => this.invocation;

public bool Await => this.awaitOperatorObserver.HasAwaitOperatorBetween(this.creationTimestamp, this.invocationTimestamp);

public IEnumerable<Match> Matches
{
get
Expand All @@ -248,7 +261,7 @@ public IEnumerable<Match> Matches
}
}

public Recorder Next => this.returnValue?.Interceptor as Recorder;
public Recorder Next => (Unwrap.ResultIfCompletedTask(this.returnValue) as IProxy)?.Interceptor as Recorder;

public void Intercept(Invocation invocation)
{
Expand Down Expand Up @@ -277,9 +290,14 @@ public void Intercept(Invocation invocation)
{
this.returnValue = null;
}
else if (IsTaskType(returnType, out var resultType))
{
var result = CreateProxy(resultType, null, this.awaitOperatorObserver, this.matcherObserver, out _);
this.returnValue = Wrap.GetResultWrapper(returnType).Invoke(result);
}
else if (returnType.IsMockable())
{
this.returnValue = CreateProxy(returnType, null, this.matcherObserver, out _);
this.returnValue = CreateProxy(returnType, null, this.awaitOperatorObserver, this.matcherObserver, out _);
}
else
{
Expand All @@ -296,6 +314,22 @@ public void Intercept(Invocation invocation)
invocation.Return();
}
}

private static bool IsTaskType(Type type, out Type resultType)
{
if (type.IsGenericType)
{
var typeDef = type.GetGenericTypeDefinition();
if (typeDef == typeof(Task<>) || typeDef == typeof(ValueTask<>))
{
resultType = type.GetGenericArguments()[0];
return true;
}
}

resultType = null;
return false;
}
}
}
}
94 changes: 94 additions & 0 deletions src/Moq/AwaitOperatorObserver.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Moq
{
internal sealed class AwaitOperatorObserver : IDisposable
{
[ThreadStatic]
private static Stack<AwaitOperatorObserver> activations;

public static AwaitOperatorObserver Activate()
{
var activation = new AwaitOperatorObserver();

var activations = AwaitOperatorObserver.activations;
if (activations == null)
{
AwaitOperatorObserver.activations = activations = new Stack<AwaitOperatorObserver>();
}
activations.Push(activation);

return activation;
}

public static bool IsActive(out AwaitOperatorObserver observer)
{
var activations = AwaitOperatorObserver.activations;

if (activations != null && activations.Count > 0)
{
observer = activations.Peek();
return true;
}
else
{
observer = null;
return false;
}
}

private int timestamp;
private List<int> observations;

private AwaitOperatorObserver()
{
}

public void Dispose()
{
var activations = AwaitOperatorObserver.activations;
Debug.Assert(activations != null && activations.Count > 0);
activations.Pop();
}

/// <summary>
/// Returns the current timestamp. The next call will return a timestamp greater than this one,
/// allowing you to order invocations and matcher observations.
/// </summary>
public int GetNextTimestamp()
{
return ++this.timestamp;
}

/// <summary>
/// Adds the specified <see cref="Match"/> as an observation.
/// </summary>
public void OnAwaitOperator()
{
if (this.observations == null)
{
this.observations = new List<int>();
}

this.observations.Add(this.GetNextTimestamp());
}

public bool HasAwaitOperatorBetween(int fromTimestampInclusive, int toTimestampExclusive)
{
if (this.observations != null)
{
return this.observations.Any(o => fromTimestampInclusive <= o && o < toTimestampExclusive);
}
else
{
return false;
}
}
}
}
5 changes: 5 additions & 0 deletions src/Moq/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ void Split(Expression e, out Expression r /* remainder */, out InvocationShape p
method,
arguments);
}
else if (methodCallExpression.Method.DeclaringType == typeof(Moq.Linq.Expressions.AwaitOperator))
{
Split(methodCallExpression.Arguments.Single(), out r, out p);
p.Await = true;
}
else
{
Debug.Assert(methodCallExpression.Method.IsExtensionMethod());
Expand Down
2 changes: 2 additions & 0 deletions src/Moq/InvocationShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public InvocationShape(LambdaExpression expression, MethodInfo method, IReadOnly
this.exactGenericTypeArguments = exactGenericTypeArguments;
}

internal bool Await { get; set; }

public void Deconstruct(out LambdaExpression expression, out MethodInfo method, out IReadOnlyList<Expression> arguments)
{
expression = this.Expression;
Expand Down
45 changes: 45 additions & 0 deletions src/Moq/Linq/Expressions/AwaitOperator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System.Threading.Tasks;

namespace Moq.Linq.Expressions
{
/// <todo/>
public static class AwaitOperator
{
/// <todo/>
public static TResult Await<TResult>(Task<TResult> task)
{
return Impl(task.Result);
}

/// <todo/>
public static TResult Await<TResult>(ValueTask<TResult> task)
{
return Impl(task.Result);
}

/// <todo/>
public static TResult Result<TResult>(this Task<TResult> task)
{
return Impl(task.Result);
}

/// <todo/>
public static TResult Result<TResult>(this ValueTask<TResult> task)
{
return Impl(task.Result);
}

private static TResult Impl<TResult>(TResult result)
{
if (AwaitOperatorObserver.IsActive(out var o))
{
o.OnAwaitOperator();
}

return result;
}
}
}
12 changes: 2 additions & 10 deletions src/Moq/LookupOrFallbackDefaultValueProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,14 @@ private object CreateTaskOf(Type type, Mock mock)
{
var resultType = type.GetGenericArguments()[0];
var result = this.GetDefaultValue(resultType, mock);

var tcsType = typeof(TaskCompletionSource<>).MakeGenericType(resultType);
var tcs = Activator.CreateInstance(tcsType);
tcsType.GetMethod("SetResult").Invoke(tcs, new[] { result });
return tcsType.GetProperty("Task").GetValue(tcs, null);
return Wrap.AsTask(type, result);
}

private object CreateValueTaskOf(Type type, Mock mock)
{
var resultType = type.GetGenericArguments()[0];
var result = this.GetDefaultValue(resultType, mock);

// `Activator.CreateInstance` could throw an `AmbiguousMatchException` in this use case,
// so we're explicitly selecting and calling the constructor we want to use:
var valueTaskCtor = type.GetConstructor(new[] { resultType });
return valueTaskCtor.Invoke(new object[] { result });
return Wrap.AsValueTask(type, result);
}

private object CreateValueTupleOf(Type type, Mock mock)
Expand Down
Loading