diff --git a/Source/EmptyDefaultValueProvider.cs b/Source/EmptyDefaultValueProvider.cs index c47930867..4f003d77a 100644 --- a/Source/EmptyDefaultValueProvider.cs +++ b/Source/EmptyDefaultValueProvider.cs @@ -43,6 +43,7 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using System.Threading.Tasks; namespace Moq { @@ -86,6 +87,11 @@ private static object GetReferenceTypeDefault(Type valueType) { return new object[0].AsQueryable(); } + else if (valueType == typeof(Task)) + { + // Task inherits from Task, so just return Task + return GetCompletedTaskWithResult(false); + } else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) { var genericListType = typeof(List<>).MakeGenericType(valueType.GetGenericArguments()[0]); @@ -101,6 +107,13 @@ private static object GetReferenceTypeDefault(Type valueType) .MakeGenericMethod(genericType) .Invoke(null, new[] { Activator.CreateInstance(genericListType) }); } + else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(Task<>)) + { + var genericType = valueType.GetGenericArguments()[0]; + + return GetCompletedTaskWithResult( + genericType.IsValueType ? GetValueTypeDefault(genericType) : GetReferenceTypeDefault(genericType)); + } return null; } @@ -115,5 +128,17 @@ private static object GetValueTypeDefault(Type valueType) return Activator.CreateInstance(valueType); } + + private static Task GetCompletedTaskWithResult(TResult value) + { + var type = value.GetType(); + var tcs = Activator.CreateInstance(typeof (TaskCompletionSource<>).MakeGenericType(type)); + + var setResultMethod = tcs.GetType().GetMethod("SetResult"); + var taskProperty = tcs.GetType().GetProperty("Task"); + + setResultMethod.Invoke(tcs, new object[] {value}); + return (Task) taskProperty.GetValue(tcs, null); + } } } diff --git a/UnitTests/EmptyDefaultValueProviderFixture.cs b/UnitTests/EmptyDefaultValueProviderFixture.cs index a02732444..23ef22ef3 100644 --- a/UnitTests/EmptyDefaultValueProviderFixture.cs +++ b/UnitTests/EmptyDefaultValueProviderFixture.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using Xunit; namespace Moq.Tests @@ -107,6 +108,41 @@ public void ProvideEmptyQueryableObjects() Assert.Equal(0, ((IQueryable)value).Cast().Count()); } + [Fact] + public void ProvidesDefaultTask() + { + var provider = new EmptyDefaultValueProvider(); + + var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskValue").GetGetMethod()); + + Assert.NotNull(value); + Assert.True(((Task)value).IsCompleted); + } + + [Fact] + public void ProvidesDefaultGenericTask() + { + var provider = new EmptyDefaultValueProvider(); + + var value = provider.ProvideDefault(typeof(IFoo).GetProperty("GenericTaskValue").GetGetMethod()); + + Assert.NotNull(value); + Assert.True(((Task)value).IsCompleted); + Assert.Equal(default(int), ((Task)value).Result); + } + + [Fact] + public void ProvidesDefaultTaskOfGenericTask() + { + var provider = new EmptyDefaultValueProvider(); + + var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskOfGenericTaskValue").GetGetMethod()); + + Assert.NotNull(value); + Assert.True(((Task)value).IsCompleted); + Assert.Equal(default(int), ((Task>) value).Result.Result); + } + public interface IFoo { object Object { get; set; } @@ -120,6 +156,9 @@ public interface IFoo IBar[] Bars { get; set; } IQueryable Queryable { get; } IQueryable QueryableObjects { get; } + Task TaskValue { get; set; } + Task GenericTaskValue { get; set; } + Task> TaskOfGenericTaskValue { get; set; } } public interface IBar { }