Skip to content

Commit

Permalink
Add Await operator for use in async setup expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
stakx committed Apr 25, 2020
1 parent 14f2157 commit f60fcb4
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 22 deletions.
5 changes: 5 additions & 0 deletions src/Moq/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ void Split(Expression e, out Expression r /* remainder */, out InvocationShape p
method,
arguments);
}
else if (methodCallExpression.Method.DeclaringType == typeof(Moq.Linq.Expressions.AwaitOperator))
{
Split(methodCallExpression.Arguments.Single(), out r, out p);
p.Await = true;
}
else
{
Debug.Assert(methodCallExpression.Method.IsExtensionMethod());
Expand Down
2 changes: 2 additions & 0 deletions src/Moq/InvocationShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public InvocationShape(LambdaExpression expression, MethodInfo method, IReadOnly
this.exactGenericTypeArguments = exactGenericTypeArguments;
}

internal bool Await { get; set; }

public void Deconstruct(out LambdaExpression expression, out MethodInfo method, out IReadOnlyList<Expression> arguments)
{
expression = this.Expression;
Expand Down
35 changes: 35 additions & 0 deletions src/Moq/Linq/Expressions/AwaitOperator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System.Threading.Tasks;

namespace Moq.Linq.Expressions
{
/// <todo/>
public static class AwaitOperator
{
/// <todo/>
public static TResult Await<TResult>(Task<TResult> task)
{
return task.Result;
}

/// <todo/>
public static TResult Await<TResult>(ValueTask<TResult> task)
{
return task.Result;
}

/// <todo/>
public static TResult Result<TResult>(this Task<TResult> task)
{
return task.Result;
}

/// <todo/>
public static TResult Result<TResult>(this ValueTask<TResult> task)
{
return task.Result;
}
}
}
12 changes: 2 additions & 10 deletions src/Moq/LookupOrFallbackDefaultValueProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,14 @@ private object CreateTaskOf(Type type, Mock mock)
{
var resultType = type.GetGenericArguments()[0];
var result = this.GetDefaultValue(resultType, mock);

var tcsType = typeof(TaskCompletionSource<>).MakeGenericType(resultType);
var tcs = Activator.CreateInstance(tcsType);
tcsType.GetMethod("SetResult").Invoke(tcs, new[] { result });
return tcsType.GetProperty("Task").GetValue(tcs, null);
return Wrap.AsTask(type, result);
}

private object CreateValueTaskOf(Type type, Mock mock)
{
var resultType = type.GetGenericArguments()[0];
var result = this.GetDefaultValue(resultType, mock);

// `Activator.CreateInstance` could throw an `AmbiguousMatchException` in this use case,
// so we're explicitly selecting and calling the constructor we want to use:
var valueTaskCtor = type.GetConstructor(new[] { resultType });
return valueTaskCtor.Invoke(new object[] { result });
return Wrap.AsValueTask(type, result);
}

private object CreateValueTupleOf(Type type, Mock mock)
Expand Down
101 changes: 89 additions & 12 deletions src/Moq/MethodCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;

using Moq.Properties;

Expand Down Expand Up @@ -223,12 +224,64 @@ public void SetRaiseEventResponse<TMock>(Action<TMock> eventExpression, params o
this.raiseEventResponse = new RaiseEventResponse(this.Mock, expression, null, args);
}

private Func<object, object> MakeWrap()
{
if (this.Expectation.Await)
{
var returnType = this.Method.ReturnType;
Debug.Assert(returnType.IsGenericType);
var genericTypeDef = returnType.GetGenericTypeDefinition();
if (genericTypeDef == typeof(Task<>))
{
return value => Wrap.AsTask(returnType, value);
}
else if (genericTypeDef == typeof(ValueTask<>))
{
return value => Wrap.AsValueTask(returnType, value);
}
else
{
throw new NotSupportedException();
}
}
else
{
return null;
}
}

private Func<Exception, object> MakeExceptionWrap()
{
if (this.Expectation.Await)
{
var returnType = this.Method.ReturnType;
Debug.Assert(returnType.IsGenericType);
var genericTypeDef = returnType.GetGenericTypeDefinition();
if (genericTypeDef == typeof(Task<>))
{
return exception => Wrap.AsFaultedTask(returnType, exception);
}
else if (genericTypeDef == typeof(ValueTask<>))
{
return exception => Wrap.AsFaultedValueTask(returnType, exception);
}
else
{
throw new NotSupportedException();
}
}
else
{
return null;
}
}

public void SetEagerReturnsResponse(object value)
{
Debug.Assert((this.flags & Flags.MethodIsNonVoid) != 0);
Debug.Assert(this.returnOrThrowResponse == null);

this.returnOrThrowResponse = new ReturnEagerValueResponse(value);
this.returnOrThrowResponse = new ReturnEagerValueResponse(value, this.MakeWrap());
}

public void SetReturnsResponse(Delegate valueFactory)
Expand All @@ -244,15 +297,15 @@ public void SetReturnsResponse(Delegate valueFactory)
// and instead of in `Returns(TResult)`, we ended up in `Returns(Delegate)` or `Returns(Func)`,
// which likely isn't what the user intended.
// So here we do what we would've done in `Returns(TResult)`:
this.returnOrThrowResponse = new ReturnEagerValueResponse(this.Method.ReturnType.GetDefaultValue());
this.returnOrThrowResponse = new ReturnEagerValueResponse(this.Method.ReturnType.GetDefaultValue(), this.MakeWrap());
}
else if (this.Method.ReturnType == typeof(Delegate))
{
// If `TResult` is `Delegate`, that is someone is setting up the return value of a method
// that returns a `Delegate`, then we have arrived here because C# picked the wrong overload:
// We don't want to invoke the passed delegate to get a return value; the passed delegate
// already is the return value.
this.returnOrThrowResponse = new ReturnEagerValueResponse(valueFactory);
this.returnOrThrowResponse = new ReturnEagerValueResponse(valueFactory, this.MakeWrap());
}
else if (IsInvocationFunc(valueFactory))
{
Expand All @@ -261,7 +314,7 @@ public void SetReturnsResponse(Delegate valueFactory)
else
{
ValidateCallback(valueFactory);
this.returnOrThrowResponse = new ReturnLazyValueResponse(valueFactory);
this.returnOrThrowResponse = new ReturnLazyValueResponse(valueFactory, this.MakeWrap());
}

bool IsInvocationFunc(Delegate callback)
Expand Down Expand Up @@ -317,6 +370,12 @@ void ValidateCallback(Delegate callback)

var expectedReturnType = this.Method.ReturnType;

if (this.Expectation.Await)
{
Debug.Assert(this.Method.ReturnType.IsGenericType);
expectedReturnType = this.Method.ReturnType.GetGenericArguments()[0];
}

if (!expectedReturnType.IsAssignableFrom(actualReturnType))
{
// TODO: If the return type is a matcher, does the callback's return type need to be matched against it?
Expand All @@ -335,7 +394,7 @@ void ValidateCallback(Delegate callback)

public void SetThrowExceptionResponse(Exception exception)
{
this.returnOrThrowResponse = new ThrowExceptionResponse(exception);
this.returnOrThrowResponse = new ThrowExceptionResponse(exception, this.MakeExceptionWrap());
}

protected override void ResetCore()
Expand Down Expand Up @@ -470,15 +529,19 @@ public override void RespondTo(Invocation invocation)
private sealed class ReturnEagerValueResponse : Response
{
public readonly object Value;
private readonly Func<object, object> wrap;

public ReturnEagerValueResponse(object value)
public ReturnEagerValueResponse(object value, Func<object, object> wrap)
{
this.Value = value;
this.wrap = wrap;
}

public override void RespondTo(Invocation invocation)
{
invocation.Return(this.Value);
var value = this.Value;
if (this.wrap != null) value = this.wrap(value);
invocation.Return(value);
}
}

Expand All @@ -502,32 +565,46 @@ public override void RespondTo(Invocation invocation)
private sealed class ReturnLazyValueResponse : Response
{
private readonly Delegate valueFactory;
private readonly Func<object, object> wrap;

public ReturnLazyValueResponse(Delegate valueFactory)
public ReturnLazyValueResponse(Delegate valueFactory, Func<object, object> wrap)
{
this.valueFactory = valueFactory;
this.wrap = wrap;
}

public override void RespondTo(Invocation invocation)
{
invocation.Return(this.valueFactory.CompareParameterTypesTo(Type.EmptyTypes)
var value = this.valueFactory.CompareParameterTypesTo(Type.EmptyTypes)
? valueFactory.InvokePreserveStack() //we need this, for the user to be able to use parameterless methods
: valueFactory.InvokePreserveStack(invocation.Arguments)); //will throw if parameters mismatch
: valueFactory.InvokePreserveStack(invocation.Arguments); //will throw if parameters mismatch
if (this.wrap != null) value = this.wrap(value);
invocation.Return(value);
}
}

private sealed class ThrowExceptionResponse : Response
{
private readonly Exception exception;
private readonly Func<Exception, object> wrap;

public ThrowExceptionResponse(Exception exception)
public ThrowExceptionResponse(Exception exception, Func<Exception, object> wrap)
{
this.exception = exception;
this.wrap = wrap;
}

public override void RespondTo(Invocation invocation)
{
throw this.exception;
if (this.wrap != null)
{
var value = this.wrap(exception);
invocation.Return(value);
}
else
{
throw this.exception;
}
}
}

Expand Down
51 changes: 51 additions & 0 deletions src/Moq/Wrap.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Threading.Tasks;

namespace Moq
{
internal static class Wrap
{
public static object AsFaultedTask(Type type, Exception exception)
{
var resultType = type.GetGenericArguments()[0];

var tcsType = typeof(TaskCompletionSource<>).MakeGenericType(resultType);
var tcs = Activator.CreateInstance(tcsType);
tcsType.GetMethod("SetException", new Type[] { typeof(Exception) }).Invoke(tcs, new[] { exception });
return tcsType.GetProperty("Task").GetValue(tcs, null);
}

public static object AsTask(Type type, object result)
{
var resultType = type.GetGenericArguments()[0];

var tcsType = typeof(TaskCompletionSource<>).MakeGenericType(resultType);
var tcs = Activator.CreateInstance(tcsType);
tcsType.GetMethod("SetResult").Invoke(tcs, new[] { result });
return tcsType.GetProperty("Task").GetValue(tcs, null);
}

public static object AsFaultedValueTask(Type type, Exception exception)
{
var resultType = type.GetGenericArguments()[0];

// `Activator.CreateInstance` could throw an `AmbiguousMatchException` in this use case,
// so we're explicitly selecting and calling the constructor we want to use:
var valueTaskCtor = type.GetConstructor(new[] { typeof(Task<>).MakeGenericType(resultType) });
return valueTaskCtor.Invoke(new object[] { Wrap.AsFaultedTask(type, exception) });
}

public static object AsValueTask(Type type, object result)
{
var resultType = type.GetGenericArguments()[0];

// `Activator.CreateInstance` could throw an `AmbiguousMatchException` in this use case,
// so we're explicitly selecting and calling the constructor we want to use:
var valueTaskCtor = type.GetConstructor(new[] { resultType });
return valueTaskCtor.Invoke(new object[] { result });
}
}
}

0 comments on commit f60fcb4

Please sign in to comment.