From a53e84ef6520d019510af3d256182d264ab83e41 Mon Sep 17 00:00:00 2001 From: Alex Tercete Date: Tue, 10 Dec 2013 20:32:16 -0200 Subject: [PATCH] Avoid NullReferenceException with async methods When using async methods and loose Mock Behavior, calls to Task methods give NullReferenceException, since default(Task) is null. Calls to Task.Wait() and Task.Result should follow the loose behavior: do nothing and return the default value for T, respectively. Related to #64. --- Source/EmptyDefaultValueProvider.cs | 25 ++++++++++++ UnitTests/EmptyDefaultValueProviderFixture.cs | 39 +++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/Source/EmptyDefaultValueProvider.cs b/Source/EmptyDefaultValueProvider.cs index c47930867..4f003d77a 100644 --- a/Source/EmptyDefaultValueProvider.cs +++ b/Source/EmptyDefaultValueProvider.cs @@ -43,6 +43,7 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using System.Threading.Tasks; namespace Moq { @@ -86,6 +87,11 @@ private static object GetReferenceTypeDefault(Type valueType) { return new object[0].AsQueryable(); } + else if (valueType == typeof(Task)) + { + // Task inherits from Task, so just return Task + return GetCompletedTaskWithResult(false); + } else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) { var genericListType = typeof(List<>).MakeGenericType(valueType.GetGenericArguments()[0]); @@ -101,6 +107,13 @@ private static object GetReferenceTypeDefault(Type valueType) .MakeGenericMethod(genericType) .Invoke(null, new[] { Activator.CreateInstance(genericListType) }); } + else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(Task<>)) + { + var genericType = valueType.GetGenericArguments()[0]; + + return GetCompletedTaskWithResult( + genericType.IsValueType ? GetValueTypeDefault(genericType) : GetReferenceTypeDefault(genericType)); + } return null; } @@ -115,5 +128,17 @@ private static object GetValueTypeDefault(Type valueType) return Activator.CreateInstance(valueType); } + + private static Task GetCompletedTaskWithResult(TResult value) + { + var type = value.GetType(); + var tcs = Activator.CreateInstance(typeof (TaskCompletionSource<>).MakeGenericType(type)); + + var setResultMethod = tcs.GetType().GetMethod("SetResult"); + var taskProperty = tcs.GetType().GetProperty("Task"); + + setResultMethod.Invoke(tcs, new object[] {value}); + return (Task) taskProperty.GetValue(tcs, null); + } } } diff --git a/UnitTests/EmptyDefaultValueProviderFixture.cs b/UnitTests/EmptyDefaultValueProviderFixture.cs index a02732444..23ef22ef3 100644 --- a/UnitTests/EmptyDefaultValueProviderFixture.cs +++ b/UnitTests/EmptyDefaultValueProviderFixture.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using Xunit; namespace Moq.Tests @@ -107,6 +108,41 @@ public void ProvideEmptyQueryableObjects() Assert.Equal(0, ((IQueryable)value).Cast().Count()); } + [Fact] + public void ProvidesDefaultTask() + { + var provider = new EmptyDefaultValueProvider(); + + var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskValue").GetGetMethod()); + + Assert.NotNull(value); + Assert.True(((Task)value).IsCompleted); + } + + [Fact] + public void ProvidesDefaultGenericTask() + { + var provider = new EmptyDefaultValueProvider(); + + var value = provider.ProvideDefault(typeof(IFoo).GetProperty("GenericTaskValue").GetGetMethod()); + + Assert.NotNull(value); + Assert.True(((Task)value).IsCompleted); + Assert.Equal(default(int), ((Task)value).Result); + } + + [Fact] + public void ProvidesDefaultTaskOfGenericTask() + { + var provider = new EmptyDefaultValueProvider(); + + var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskOfGenericTaskValue").GetGetMethod()); + + Assert.NotNull(value); + Assert.True(((Task)value).IsCompleted); + Assert.Equal(default(int), ((Task>) value).Result.Result); + } + public interface IFoo { object Object { get; set; } @@ -120,6 +156,9 @@ public interface IFoo IBar[] Bars { get; set; } IQueryable Queryable { get; } IQueryable QueryableObjects { get; } + Task TaskValue { get; set; } + Task GenericTaskValue { get; set; } + Task> TaskOfGenericTaskValue { get; set; } } public interface IBar { }