Skip to content

Commit

Permalink
Support Await in SetupSet
Browse files Browse the repository at this point in the history
  • Loading branch information
stakx committed Apr 25, 2020
1 parent 91eabe2 commit 713f989
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 12 deletions.
50 changes: 42 additions & 8 deletions src/Moq/ActionObserver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

using System.Threading.Tasks;
using Moq.Expressions.Visitors;
using Moq.Internals;
using Moq.Linq.Expressions;
using Moq.Properties;

using TypeNameFormatter;
Expand All @@ -28,10 +29,11 @@ internal sealed class ActionObserver : ExpressionReconstructor
{
public override Expression<Action<T>> ReconstructExpression<T>(Action<T> action, object[] ctorArgs = null)
{
using (var awaitOperatorObserver = AwaitOperatorObserver.Activate())
using (var matcherObserver = MatcherObserver.Activate())
{
// Create the root recording proxy:
var root = (T)CreateProxy(typeof(T), ctorArgs, matcherObserver, out var rootRecorder);
var root = (T)CreateProxy(typeof(T), ctorArgs, awaitOperatorObserver, matcherObserver, out var rootRecorder);

Exception error = null;
try
Expand Down Expand Up @@ -61,6 +63,12 @@ public override Expression<Action<T>> ReconstructExpression<T>(Action<T> action,
if (invocation != null)
{
body = Expression.Call(body, invocation.Method, GetArgumentExpressions(invocation, recorder.Matches.ToArray()));
if (recorder.Await)
{
var genArg = invocation.Method.ReturnType.GetGenericArguments()[0];
var awaitMethod = typeof(AwaitOperator).GetMethods("Await").First().MakeGenericMethod(genArg);
body = Expression.Call(awaitMethod, body);
}
}
else
{
Expand Down Expand Up @@ -213,32 +221,37 @@ bool CanDistribute(int msi, int asi)
}

// Creates a proxy (way more light-weight than a `Mock<T>`!) with an invocation `Recorder` attached to it.
private static IProxy CreateProxy(Type type, object[] ctorArgs, MatcherObserver matcherObserver, out Recorder recorder)
private static IProxy CreateProxy(Type type, object[] ctorArgs, AwaitOperatorObserver awaitOperatorObserver, MatcherObserver matcherObserver, out Recorder recorder)
{
recorder = new Recorder(matcherObserver);
recorder = new Recorder(awaitOperatorObserver, matcherObserver);
return (IProxy)ProxyFactory.Instance.CreateProxy(type, recorder, Type.EmptyTypes, ctorArgs ?? new object[0]);
}

// Records an invocation, mocks return values, and builds a chain to the return value's recorder.
// This record represents the basis for reconstructing an expression tree.
private sealed class Recorder : IInterceptor
{
private readonly AwaitOperatorObserver awaitOperatorObserver;
private readonly MatcherObserver matcherObserver;
private int creationTimestamp;
private Invocation invocation;
private int invocationTimestamp;
private IProxy returnValue;
private object returnValue;

public Recorder(MatcherObserver matcherObserver)
public Recorder(AwaitOperatorObserver awaitOperatorObserver, MatcherObserver matcherObserver)
{
Debug.Assert(awaitOperatorObserver != null);
Debug.Assert(matcherObserver != null);

this.awaitOperatorObserver = awaitOperatorObserver;
this.matcherObserver = matcherObserver;
this.creationTimestamp = this.matcherObserver.GetNextTimestamp();
}

public Invocation Invocation => this.invocation;

public bool Await => this.awaitOperatorObserver.HasAwaitOperatorBetween(this.creationTimestamp, this.invocationTimestamp);

public IEnumerable<Match> Matches
{
get
Expand All @@ -248,7 +261,7 @@ public IEnumerable<Match> Matches
}
}

public Recorder Next => this.returnValue?.Interceptor as Recorder;
public Recorder Next => (Unwrap.ResultIfCompletedTask(this.returnValue) as IProxy)?.Interceptor as Recorder;

public void Intercept(Invocation invocation)
{
Expand Down Expand Up @@ -277,9 +290,14 @@ public void Intercept(Invocation invocation)
{
this.returnValue = null;
}
else if (IsTaskType(returnType, out var resultType))
{
var result = CreateProxy(resultType, null, this.awaitOperatorObserver, this.matcherObserver, out _);
this.returnValue = Wrap.GetResultWrapper(returnType).Invoke(result);
}
else if (returnType.IsMockable())
{
this.returnValue = CreateProxy(returnType, null, this.matcherObserver, out _);
this.returnValue = CreateProxy(returnType, null, this.awaitOperatorObserver, this.matcherObserver, out _);
}
else
{
Expand All @@ -296,6 +314,22 @@ public void Intercept(Invocation invocation)
invocation.Return();
}
}

private static bool IsTaskType(Type type, out Type resultType)
{
if (type.IsGenericType)
{
var typeDef = type.GetGenericTypeDefinition();
if (typeDef == typeof(Task<>) || typeDef == typeof(ValueTask<>))
{
resultType = type.GetGenericArguments()[0];
return true;
}
}

resultType = null;
return false;
}
}
}
}
94 changes: 94 additions & 0 deletions src/Moq/AwaitOperatorObserver.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Moq
{
internal sealed class AwaitOperatorObserver : IDisposable
{
[ThreadStatic]
private static Stack<AwaitOperatorObserver> activations;

public static AwaitOperatorObserver Activate()
{
var activation = new AwaitOperatorObserver();

var activations = AwaitOperatorObserver.activations;
if (activations == null)
{
AwaitOperatorObserver.activations = activations = new Stack<AwaitOperatorObserver>();
}
activations.Push(activation);

return activation;
}

public static bool IsActive(out AwaitOperatorObserver observer)
{
var activations = AwaitOperatorObserver.activations;

if (activations != null && activations.Count > 0)
{
observer = activations.Peek();
return true;
}
else
{
observer = null;
return false;
}
}

private int timestamp;
private List<int> observations;

private AwaitOperatorObserver()
{
}

public void Dispose()
{
var activations = AwaitOperatorObserver.activations;
Debug.Assert(activations != null && activations.Count > 0);
activations.Pop();
}

/// <summary>
/// Returns the current timestamp. The next call will return a timestamp greater than this one,
/// allowing you to order invocations and matcher observations.
/// </summary>
public int GetNextTimestamp()
{
return ++this.timestamp;
}

/// <summary>
/// Adds the specified <see cref="Match"/> as an observation.
/// </summary>
public void OnAwaitOperator()
{
if (this.observations == null)
{
this.observations = new List<int>();
}

this.observations.Add(this.GetNextTimestamp());
}

public bool HasAwaitOperatorBetween(int fromTimestampInclusive, int toTimestampExclusive)
{
if (this.observations != null)
{
return this.observations.Any(o => fromTimestampInclusive <= o && o < toTimestampExclusive);
}
else
{
return false;
}
}
}
}
18 changes: 14 additions & 4 deletions src/Moq/Linq/Expressions/AwaitOperator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,35 @@ public static class AwaitOperator
/// <todo/>
public static TResult Await<TResult>(Task<TResult> task)
{
return task.Result;
return Impl(task.Result);
}

/// <todo/>
public static TResult Await<TResult>(ValueTask<TResult> task)
{
return task.Result;
return Impl(task.Result);
}

/// <todo/>
public static TResult Result<TResult>(this Task<TResult> task)
{
return task.Result;
return Impl(task.Result);
}

/// <todo/>
public static TResult Result<TResult>(this ValueTask<TResult> task)
{
return task.Result;
return Impl(task.Result);
}

private static TResult Impl<TResult>(TResult result)
{
if (AwaitOperatorObserver.IsActive(out var o))
{
o.OnAwaitOperator();
}

return result;
}
}
}

0 comments on commit 713f989

Please sign in to comment.