Skip to content

Commit

Permalink
Pass invocation with proceedInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
JSkimming committed Mar 25, 2019
1 parent 96d9a0f commit d7b90f8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 27 deletions.
35 changes: 22 additions & 13 deletions src/Castle.Core.AsyncInterceptor/AsyncInterceptorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void IAsyncInterceptor.InterceptSynchronous(IInvocation invocation)
/// <param name="invocation">The method invocation.</param>
void IAsyncInterceptor.InterceptAsynchronous(IInvocation invocation)
{
invocation.ReturnValue = InterceptAsync(invocation.GetProceedInfo(), ProceedAsynchronous);
invocation.ReturnValue = InterceptAsync(invocation, invocation.GetProceedInfo(), ProceedAsynchronous);
}

/// <summary>
Expand All @@ -61,29 +61,34 @@ void IAsyncInterceptor.InterceptAsynchronous(IInvocation invocation)
/// <param name="invocation">The method invocation.</param>
void IAsyncInterceptor.InterceptAsynchronous<TResult>(IInvocation invocation)
{
invocation.ReturnValue = InterceptAsync(invocation.GetProceedInfo(), ProceedAsynchronous<TResult>);
invocation.ReturnValue =
InterceptAsync(invocation, invocation.GetProceedInfo(), ProceedAsynchronous<TResult>);
}

/// <summary>
/// Override in derived classes to intercept method invocations.
/// </summary>
/// <param name="invocation">The method invocation.</param>
/// <param name="proceedInfo">The <see cref="IInvocationProceedInfo"/>.</param>
/// <param name="proceed">The function to proceed the <paramref name="proceedInfo"/>.</param>
/// <returns>A <see cref="Task" /> object that represents the asynchronous operation.</returns>
protected abstract Task InterceptAsync(
IInvocation invocation,
IInvocationProceedInfo proceedInfo,
Func<IInvocationProceedInfo, Task> proceed);
Func<IInvocation, IInvocationProceedInfo, Task> proceed);

/// <summary>
/// Override in derived classes to intercept method invocations.
/// </summary>
/// <typeparam name="TResult">The type of the <see cref="Task{T}"/> <see cref="Task{T}.Result"/>.</typeparam>
/// <param name="invocation">The method invocation.</param>
/// <param name="proceedInfo">The <see cref="IInvocationProceedInfo"/>.</param>
/// <param name="proceed">The function to proceed the <paramref name="proceedInfo"/>.</param>
/// <returns>A <see cref="Task" /> object that represents the asynchronous operation.</returns>
protected abstract Task<TResult> InterceptAsync<TResult>(
IInvocation invocation,
IInvocationProceedInfo proceedInfo,
Func<IInvocationProceedInfo, Task<TResult>> proceed);
Func<IInvocation, IInvocationProceedInfo, Task<TResult>> proceed);

private static GenericSynchronousHandler CreateHandler(Type returnType)
{
Expand All @@ -93,7 +98,7 @@ private static GenericSynchronousHandler CreateHandler(Type returnType)

private static void InterceptSynchronousVoid(AsyncInterceptorBase me, IInvocation invocation)
{
Task task = me.InterceptAsync(invocation.GetProceedInfo(), ProceedSynchronous);
Task task = me.InterceptAsync(invocation, invocation.GetProceedInfo(), ProceedSynchronous);

// If the intercept task has yet to complete, wait for it.
if (!task.IsCompleted)
Expand All @@ -112,7 +117,7 @@ private static void InterceptSynchronousVoid(AsyncInterceptorBase me, IInvocatio

private static void InterceptSynchronousResult<TResult>(AsyncInterceptorBase me, IInvocation invocation)
{
Task<TResult> task = me.InterceptAsync(invocation.GetProceedInfo(), ProceedSynchronous<TResult>);
Task<TResult> task = me.InterceptAsync(invocation, invocation.GetProceedInfo(), ProceedSynchronous<TResult>);

// If the intercept task has yet to complete, wait for it.
if (!task.IsCompleted)
Expand All @@ -129,7 +134,7 @@ private static void InterceptSynchronousResult<TResult>(AsyncInterceptorBase me,
}
}

private static Task ProceedSynchronous(IInvocationProceedInfo proceedInfo)
private static Task ProceedSynchronous(IInvocation invocation, IInvocationProceedInfo proceedInfo)
{
try
{
Expand All @@ -152,12 +157,14 @@ private static Task ProceedSynchronous(IInvocationProceedInfo proceedInfo)
}
}

private static Task<TResult> ProceedSynchronous<TResult>(IInvocationProceedInfo proceedInfo)
private static Task<TResult> ProceedSynchronous<TResult>(
IInvocation invocation,
IInvocationProceedInfo proceedInfo)
{
try
{
proceedInfo.Invoke();
return Task.FromResult((TResult)proceedInfo.Invocation.ReturnValue);
return Task.FromResult((TResult)invocation.ReturnValue);
}
catch (Exception e)
{
Expand All @@ -171,22 +178,24 @@ private static Task<TResult> ProceedSynchronous<TResult>(IInvocationProceedInfo
}
}

private static async Task ProceedAsynchronous(IInvocationProceedInfo proceedInfo)
private static async Task ProceedAsynchronous(IInvocation invocation, IInvocationProceedInfo proceedInfo)
{
proceedInfo.Invoke();

// Get the task to await.
var originalReturnValue = (Task)proceedInfo.Invocation.ReturnValue;
var originalReturnValue = (Task)invocation.ReturnValue;

await originalReturnValue.ConfigureAwait(false);
}

private static async Task<TResult> ProceedAsynchronous<TResult>(IInvocationProceedInfo proceedInfo)
private static async Task<TResult> ProceedAsynchronous<TResult>(
IInvocation invocation,
IInvocationProceedInfo proceedInfo)
{
proceedInfo.Invoke();

// Get the task to await.
var originalReturnValue = (Task<TResult>)proceedInfo.Invocation.ReturnValue;
var originalReturnValue = (Task<TResult>)invocation.ReturnValue;

TResult result = await originalReturnValue.ConfigureAwait(false);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,20 @@ public class WhenExceptionInterceptingAnAsynchronousMethodThatThrowsASynchronous
{
private class MyInterceptorBase : AsyncInterceptorBase
{
protected override Task InterceptAsync(IInvocationProceedInfo proceedInfo, Func<IInvocationProceedInfo, Task> proceed)
protected override Task InterceptAsync(
IInvocation invocation,
IInvocationProceedInfo proceedInfo,
Func<IInvocation, IInvocationProceedInfo, Task> proceed)
{
return proceed(proceedInfo);
return proceed(invocation, proceedInfo);
}

protected override Task<TResult> InterceptAsync<TResult>(
IInvocation invocation,
IInvocationProceedInfo proceedInfo,
Func<IInvocationProceedInfo, Task<TResult>> proceed)
Func<IInvocation, IInvocationProceedInfo, Task<TResult>> proceed)
{
return proceed(proceedInfo);
return proceed(invocation, proceedInfo);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,46 +18,50 @@ public TestAsyncInterceptorBase(ListLogger log, int msDeley)
_msDeley = msDeley;
}

protected override async Task InterceptAsync(IInvocationProceedInfo proceedInfo, Func<IInvocationProceedInfo, Task> proceed)
protected override async Task InterceptAsync(
IInvocation invocation,
IInvocationProceedInfo proceedInfo,
Func<IInvocation, IInvocationProceedInfo, Task> proceed)
{
try
{
_log.Add($"{proceedInfo.Invocation.Method.Name}:StartingVoidInvocation");
_log.Add($"{invocation.Method.Name}:StartingVoidInvocation");

await Task.Yield();
await proceed(proceedInfo).ConfigureAwait(false);
await proceed(invocation, proceedInfo).ConfigureAwait(false);

if (_msDeley > 0)
await Task.Delay(_msDeley).ConfigureAwait(false);

_log.Add($"{proceedInfo.Invocation.Method.Name}:CompletedVoidInvocation");
_log.Add($"{invocation.Method.Name}:CompletedVoidInvocation");
}
catch (Exception e)
{
_log.Add($"{proceedInfo.Invocation.Method.Name}:VoidExceptionThrown:{e.Message}");
_log.Add($"{invocation.Method.Name}:VoidExceptionThrown:{e.Message}");
throw;
}
}

protected override async Task<TResult> InterceptAsync<TResult>(
IInvocation invocation,
IInvocationProceedInfo proceedInfo,
Func<IInvocationProceedInfo, Task<TResult>> proceed)
Func<IInvocation, IInvocationProceedInfo, Task<TResult>> proceed)
{
try
{
_log.Add($"{proceedInfo.Invocation.Method.Name}:StartingResultInvocation");
_log.Add($"{invocation.Method.Name}:StartingResultInvocation");

TResult result = await proceed(proceedInfo).ConfigureAwait(false);
TResult result = await proceed(invocation, proceedInfo).ConfigureAwait(false);

if (_msDeley > 0)
await Task.Delay(_msDeley).ConfigureAwait(false);

_log.Add($"{proceedInfo.Invocation.Method.Name}:CompletedResultInvocation");
_log.Add($"{invocation.Method.Name}:CompletedResultInvocation");
return result;
}
catch (Exception e)
{
_log.Add($"{proceedInfo.Invocation.Method.Name}:ResultExceptionThrown:{e.Message}");
_log.Add($"{invocation.Method.Name}:ResultExceptionThrown:{e.Message}");
throw;
}
}
Expand Down

0 comments on commit d7b90f8

Please sign in to comment.