From 42521c47313d86b4d06c617e1bdcda22046fdcaa Mon Sep 17 00:00:00 2001 From: stakx Date: Fri, 1 Jan 2021 11:27:50 +0100 Subject: [PATCH] Add ability in `IAwaitableFactory` to create result expression --- src/Moq/Async/AwaitExpression.cs | 40 +++++++++++++++++++++++++++++ src/Moq/Async/AwaitableFactory`1.cs | 6 +++++ src/Moq/Async/AwaitableFactory`2.cs | 3 +++ src/Moq/Async/IAwaitableFactory.cs | 3 +++ src/Moq/Async/TaskFactory.cs | 1 + src/Moq/Async/TaskFactory`1.cs | 8 ++++++ src/Moq/Async/ValueTaskFactory`1.cs | 8 ++++++ 7 files changed, 69 insertions(+) create mode 100644 src/Moq/Async/AwaitExpression.cs diff --git a/src/Moq/Async/AwaitExpression.cs b/src/Moq/Async/AwaitExpression.cs new file mode 100644 index 000000000..126cee898 --- /dev/null +++ b/src/Moq/Async/AwaitExpression.cs @@ -0,0 +1,40 @@ +// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors. +// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt. + +using System; +using System.Diagnostics; +using System.Linq.Expressions; + +namespace Moq.Async +{ + internal sealed class AwaitExpression : Expression + { + private readonly IAwaitableFactory awaitableFactory; + private readonly Expression operand; + + public AwaitExpression(Expression operand, IAwaitableFactory awaitableFactory) + { + Debug.Assert(awaitableFactory != null); + Debug.Assert(operand != null); + + this.awaitableFactory = awaitableFactory; + this.operand = operand; + } + + public override bool CanReduce => false; + + public override ExpressionType NodeType => ExpressionType.Extension; + + public Expression Operand => this.operand; + + public override Type Type => this.awaitableFactory.ResultType; + + public override string ToString() + { + return this.awaitableFactory.ResultType == typeof(void) ? $"await {this.operand}" + : $"(await {this.operand})"; + } + + protected override Expression VisitChildren(ExpressionVisitor visitor) => this; + } +} diff --git a/src/Moq/Async/AwaitableFactory`1.cs b/src/Moq/Async/AwaitableFactory`1.cs index eb45d228a..c3048d013 100644 --- a/src/Moq/Async/AwaitableFactory`1.cs +++ b/src/Moq/Async/AwaitableFactory`1.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; namespace Moq.Async { @@ -44,6 +45,11 @@ object IAwaitableFactory.CreateFaulted(IEnumerable exceptions) return this.CreateFaulted(exceptions); } + Expression IAwaitableFactory.CreateResultExpression(Expression awaitableExpression) + { + return new AwaitExpression(awaitableExpression, this); + } + bool IAwaitableFactory.TryGetResult(object awaitable, out object result) { Debug.Assert(awaitable is TAwaitable); diff --git a/src/Moq/Async/AwaitableFactory`2.cs b/src/Moq/Async/AwaitableFactory`2.cs index 4cac4c456..5c5cdae85 100644 --- a/src/Moq/Async/AwaitableFactory`2.cs +++ b/src/Moq/Async/AwaitableFactory`2.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; namespace Moq.Async { @@ -46,6 +47,8 @@ object IAwaitableFactory.CreateFaulted(IEnumerable exceptions) public abstract bool TryGetResult(TAwaitable awaitable, out TResult result); + public abstract Expression CreateResultExpression(Expression awaitableExpression); + bool IAwaitableFactory.TryGetResult(object awaitable, out object result) { Debug.Assert(awaitable is TAwaitable); diff --git a/src/Moq/Async/IAwaitableFactory.cs b/src/Moq/Async/IAwaitableFactory.cs index 1c38ca940..fb87828bb 100644 --- a/src/Moq/Async/IAwaitableFactory.cs +++ b/src/Moq/Async/IAwaitableFactory.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq.Expressions; namespace Moq.Async { @@ -16,6 +17,8 @@ internal interface IAwaitableFactory object CreateFaulted(IEnumerable exceptions); + Expression CreateResultExpression(Expression awaitableExpression); + bool TryGetResult(object awaitable, out object result); } } diff --git a/src/Moq/Async/TaskFactory.cs b/src/Moq/Async/TaskFactory.cs index 874751c6d..c396b2671 100644 --- a/src/Moq/Async/TaskFactory.cs +++ b/src/Moq/Async/TaskFactory.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq.Expressions; using System.Reflection; using System.Threading.Tasks; diff --git a/src/Moq/Async/TaskFactory`1.cs b/src/Moq/Async/TaskFactory`1.cs index 41d5dde6a..2f19662f0 100644 --- a/src/Moq/Async/TaskFactory`1.cs +++ b/src/Moq/Async/TaskFactory`1.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq.Expressions; using System.Threading.Tasks; namespace Moq.Async @@ -28,6 +29,13 @@ public override Task CreateFaulted(IEnumerable exceptions) return tcs.Task; } + public override Expression CreateResultExpression(Expression awaitableExpression) + { + return Expression.MakeMemberAccess( + awaitableExpression, + typeof(Task).GetProperty(nameof(Task.Result))); + } + public override bool TryGetResult(Task task, out TResult result) { if (task.Status == TaskStatus.RanToCompletion) diff --git a/src/Moq/Async/ValueTaskFactory`1.cs b/src/Moq/Async/ValueTaskFactory`1.cs index 213921bf7..69db7d841 100644 --- a/src/Moq/Async/ValueTaskFactory`1.cs +++ b/src/Moq/Async/ValueTaskFactory`1.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq.Expressions; using System.Threading.Tasks; namespace Moq.Async @@ -28,6 +29,13 @@ public override ValueTask CreateFaulted(IEnumerable exceptio return new ValueTask(tcs.Task); } + public override Expression CreateResultExpression(Expression awaitableExpression) + { + return Expression.MakeMemberAccess( + awaitableExpression, + typeof(ValueTask).GetProperty(nameof(ValueTask.Result))); + } + public override bool TryGetResult(ValueTask valueTask, out TResult result) { if (valueTask.IsCompletedSuccessfully)