Skip to content

Commit

Permalink
Send InvocationCancel request to out-of-proc workers (#8556)
Browse files Browse the repository at this point in the history
  • Loading branch information
liliankasem authored Aug 15, 2022
1 parent c5dba8a commit 91e0d11
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 31 deletions.
3 changes: 3 additions & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
<!-- Please add your release notes in the following format:
- My change description (#PR)
-->

- Host support for out-of-proc cancellation tokens ([#2153](https://github.com/Azure/azure-functions-host/issues/2152))
- Update Python Worker Version to [4.4.0](https://github.com/Azure/azure-functions-python-worker/releases/tag/4.4.0)

**Release sprint:** Sprint 125
[ [bugs](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Abug+is%3Aclosed) | [features](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Afeature+is%3Aclosed) ]
- Fix the bug where debugging of dotnet isolated function apps hangs in visual studio (#8596)
Expand Down
50 changes: 37 additions & 13 deletions src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable
private TaskCompletionSource<List<RawFunctionMetadata>> _functionsIndexingTask = new TaskCompletionSource<List<RawFunctionMetadata>>(TaskCreationOptions.RunContinuationsAsynchronously);
private TimeSpan _functionLoadTimeout = TimeSpan.FromMinutes(1);
private bool _isSharedMemoryDataTransferEnabled;
private bool _cancelCapabilityEnabled;

private object _syncLock = new object();
private System.Timers.Timer _timer;
Expand Down Expand Up @@ -275,6 +276,7 @@ internal void WorkerInitResponse(GrpcEvent initEvent)
_state = _state | RpcWorkerChannelState.Initialized;
_workerCapabilities.UpdateCapabilities(_initMessage.Capabilities);
_isSharedMemoryDataTransferEnabled = IsSharedMemoryDataTransferEnabled();
_cancelCapabilityEnabled = !string.IsNullOrEmpty(_workerCapabilities.GetCapabilityState(RpcWorkerConstants.HandlesInvocationCancelMessage));

if (!_isSharedMemoryDataTransferEnabled)
{
Expand Down Expand Up @@ -501,36 +503,58 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context)
_workerChannelLogger.LogDebug($"Function {context.FunctionMetadata.Name} failed to load");
context.ResultSource.TrySetException(_functionLoadErrors[context.FunctionMetadata.GetFunctionId()]);
_executingInvocations.TryRemove(context.ExecutionContext.InvocationId.ToString(), out ScriptInvocationContext _);
return;
}
else if (_metadataRequestErrors.ContainsKey(context.FunctionMetadata.GetFunctionId()))
{
_workerChannelLogger.LogDebug($"Worker failed to load metadata for {context.FunctionMetadata.Name}");
context.ResultSource.TrySetException(_metadataRequestErrors[context.FunctionMetadata.GetFunctionId()]);
_executingInvocations.TryRemove(context.ExecutionContext.InvocationId.ToString(), out ScriptInvocationContext _);
return;
}
else

if (context.CancellationToken.IsCancellationRequested)
{
if (context.CancellationToken.IsCancellationRequested)
{
context.ResultSource.SetCanceled();
return;
}
var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager);
AddAdditionalTraceContext(invocationRequest.TraceContext.Attributes, context);
_executingInvocations.TryAdd(invocationRequest.InvocationId, context);
_workerChannelLogger.LogDebug("Cancellation has been requested, cancelling invocation request");
context.ResultSource.SetCanceled();
return;
}

SendStreamingMessage(new StreamingMessage
{
InvocationRequest = invocationRequest
});
var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager);
AddAdditionalTraceContext(invocationRequest.TraceContext.Attributes, context);
_executingInvocations.TryAdd(invocationRequest.InvocationId, context);

if (_cancelCapabilityEnabled)
{
context.CancellationToken.Register(() => SendInvocationCancel(invocationRequest.InvocationId));
}

SendStreamingMessage(new StreamingMessage
{
InvocationRequest = invocationRequest
});
}
catch (Exception invokeEx)
{
context.ResultSource.TrySetException(invokeEx);
}
}

internal void SendInvocationCancel(string invocationId)
{
_workerChannelLogger.LogDebug($"Sending invocation cancel request for InvocationId {invocationId}");

var invocationCancel = new InvocationCancel
{
InvocationId = invocationId
};

SendStreamingMessage(new StreamingMessage
{
InvocationCancel = invocationCancel
});
}

// gets metadata from worker
public Task<List<RawFunctionMetadata>> GetFunctionMetadata()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,19 @@ public static bool IsFailure(this StatusResult statusResult, out Exception excep
}

/// <summary>
/// This method is only hit on the invocation code path. enableUserCodeExceptionCapability = feature flag,
/// exposed as a capability that is set by the worker.
/// This method is only hit on the invocation code path.
/// enableUserCodeExceptionCapability = feature flag exposed as a capability that is set by the worker.
/// </summary>
public static bool IsInvocationSuccess<T>(this StatusResult status, TaskCompletionSource<T> tcs, bool enableUserCodeExceptionCapability = false)
{
switch (status.Status)
{
case StatusResult.Types.Status.Failure:
case StatusResult.Types.Status.Cancelled:
var rpcException = GetRpcException(status, enableUserCodeExceptionCapability);
tcs.SetException(rpcException);
return false;

case StatusResult.Types.Status.Cancelled:
tcs.SetCanceled();
return false;

default:
return true;
}
Expand All @@ -65,6 +62,7 @@ public static Workers.Rpc.RpcException GetRpcException(StatusResult statusResult
{
return new Workers.Rpc.RpcException(status, ex.Message, ex.StackTrace, ex.Type, ex.IsUserException);
}

return new Workers.Rpc.RpcException(status, ex.Message, ex.StackTrace);
}
return new Workers.Rpc.RpcException(status, string.Empty, string.Empty);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public async Task OnTimeoutExceptionAsync(ExceptionDispatchInfo exceptionInfo, T

if (timeoutException?.Task != null)
{
// We may double the timeoutGracePeriod here by first waiting to see if the iniital
// We may double the timeoutGracePeriod here by first waiting to see if the initial
// function task that started the exception has completed.
Task completedTask = await Task.WhenAny(timeoutException.Task, Task.Delay(timeoutGracePeriod));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Azure.WebJobs.Description;
Expand Down Expand Up @@ -54,6 +55,9 @@ protected override async Task<Collection<ParameterDescriptor>> GetFunctionParame
{
var parameters = await base.GetFunctionParametersAsync(functionInvoker, functionMetadata, triggerMetadata, methodAttributes, inputBindings, outputBindings);

// Add cancellation token
parameters.Add(new ParameterDescriptor(ScriptConstants.SystemCancellationTokenParameterName, typeof(CancellationToken)));

var bindings = inputBindings.Union(outputBindings);

try
Expand Down
18 changes: 14 additions & 4 deletions src/WebJobs.Script/Description/Workers/WorkerFunctionInvoker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ protected override async Task<object> InvokeCore(object[] parameters, FunctionIn
await DelayUntilFunctionDispatcherInitializedOrShutdown();
}

var triggerParameterIndex = 0;
var cancellationTokenParameterIndex = 4;
var bindingData = context.Binder.BindingData;
object triggerValue = TransformInput(parameters[0], bindingData);
object triggerValue = TransformInput(parameters[triggerParameterIndex], bindingData);
var triggerInput = (_bindingMetadata.Name, _bindingMetadata.DataType ?? DataType.String, triggerValue);
IEnumerable<(string, DataType, object)> inputs = new[] { triggerInput };
if (_inputBindings.Count > 1)
Expand All @@ -84,9 +86,7 @@ protected override async Task<object> InvokeCore(object[] parameters, FunctionIn
Traceparent = Activity.Current?.Id,
Tracestate = Activity.Current?.TraceStateString,
Attributes = Activity.Current?.Tags,

// TODO: link up cancellation token to parameter descriptors
CancellationToken = CancellationToken.None,
CancellationToken = HandleCancellationTokenParameter(parameters[cancellationTokenParameterIndex]),
Logger = context.Logger
};

Expand Down Expand Up @@ -187,6 +187,16 @@ private object TransformInput(object input, Dictionary<string, object> bindingDa
return input;
}

private CancellationToken HandleCancellationTokenParameter(object input)
{
if (input == null)
{
return CancellationToken.None;
}

return (CancellationToken)input;
}

private void HandleReturnParameter(ScriptInvocationResult result)
{
result.Outputs[ScriptConstants.SystemReturnParameterBindingName] = result.Return;
Expand Down
2 changes: 1 addition & 1 deletion src/WebJobs.Script/Host/ScriptHost.cs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ public async Task InitializeAsync(CancellationToken cancellationToken = default)
{
// Initialize worker function invocation dispatcher only for valid functions after creating function descriptors
// Dispatcher not needed for codeless function.
// Disptacher needed for non-dotnet codeless functions
// Dispatcher needed for non-dotnet codeless functions
var filteredFunctionMetadata = functionMetadataList.Where(m => !Utility.IsCodelessDotNetLanguageFunction(m));
await _functionDispatcher.InitializeAsync(Utility.GetValidFunctions(filteredFunctionMetadata, Functions), cancellationToken);
}
Expand Down
1 change: 1 addition & 0 deletions src/WebJobs.Script/ScriptConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public static class ScriptConstants
public const string SystemReturnParameterBindingName = "$return";
public const string SystemReturnParameterName = "_return";
public const string SystemLoggerParameterName = "_logger";
public const string SystemCancellationTokenParameterName = "_cancellationToken";

public const string DebugSentinelFileName = "debug_sentinel";
public const string DiagnosticSentinelFileName = "diagnostic_sentinel";
Expand Down
1 change: 1 addition & 0 deletions src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public static class RpcWorkerConstants
public const string EnableUserCodeException = "EnableUserCodeException";
public const string SupportsLoadResponseCollection = "SupportsLoadResponseCollection";
public const string HandlesWorkerTerminateMessage = "HandlesWorkerTerminateMessage";
public const string HandlesInvocationCancelMessage = "HandlesInvocationCancelMessage";

// Host Capabilities
public const string V2Compatable = "V2Compatable";
Expand Down
Loading

0 comments on commit 91e0d11

Please sign in to comment.