diff --git a/release_notes.md b/release_notes.md index db3c10d3e9..336befb50a 100644 --- a/release_notes.md +++ b/release_notes.md @@ -2,6 +2,9 @@ + +- Host support for out-of-proc cancellation tokens ([#2152](https://github.com/Azure/azure-functions-host/issues/2152)) - Update Python Worker Version to [4.4.0](https://github.com/Azure/azure-functions-python-worker/releases/tag/4.4.0) + **Release sprint:** Sprint 125 [ [bugs](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Abug+is%3Aclosed) | [features](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Afeature+is%3Aclosed) ] diff --git a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs index 2e7ab1f6a9..f8ce55dd44 100644 --- a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs +++ b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs @@ -73,6 +73,7 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable private TaskCompletionSource> _functionsIndexingTask = new TaskCompletionSource>(TaskCreationOptions.RunContinuationsAsynchronously); private TimeSpan _functionLoadTimeout = TimeSpan.FromMinutes(1); private bool _isSharedMemoryDataTransferEnabled; + private bool _cancelCapabilityEnabled; private object _syncLock = new object(); private System.Timers.Timer _timer; @@ -275,6 +276,7 @@ internal void WorkerInitResponse(GrpcEvent initEvent) _state = _state | RpcWorkerChannelState.Initialized; _workerCapabilities.UpdateCapabilities(_initMessage.Capabilities); _isSharedMemoryDataTransferEnabled = IsSharedMemoryDataTransferEnabled(); + _cancelCapabilityEnabled = !string.IsNullOrEmpty(_workerCapabilities.GetCapabilityState(RpcWorkerConstants.HandlesInvocationCancelMessage)); if (!_isSharedMemoryDataTransferEnabled) { @@ -501,29 +503,36 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context) 
_workerChannelLogger.LogDebug($"Function {context.FunctionMetadata.Name} failed to load"); context.ResultSource.TrySetException(_functionLoadErrors[context.FunctionMetadata.GetFunctionId()]); _executingInvocations.TryRemove(context.ExecutionContext.InvocationId.ToString(), out ScriptInvocationContext _); + return; } else if (_metadataRequestErrors.ContainsKey(context.FunctionMetadata.GetFunctionId())) { _workerChannelLogger.LogDebug($"Worker failed to load metadata for {context.FunctionMetadata.Name}"); context.ResultSource.TrySetException(_metadataRequestErrors[context.FunctionMetadata.GetFunctionId()]); _executingInvocations.TryRemove(context.ExecutionContext.InvocationId.ToString(), out ScriptInvocationContext _); + return; } - else + + if (context.CancellationToken.IsCancellationRequested) { - if (context.CancellationToken.IsCancellationRequested) - { - context.ResultSource.SetCanceled(); - return; - } - var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager); - AddAdditionalTraceContext(invocationRequest.TraceContext.Attributes, context); - _executingInvocations.TryAdd(invocationRequest.InvocationId, context); + _workerChannelLogger.LogDebug("Cancellation has been requested, cancelling invocation request"); + context.ResultSource.SetCanceled(); + return; + } - SendStreamingMessage(new StreamingMessage - { - InvocationRequest = invocationRequest - }); + var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager); + AddAdditionalTraceContext(invocationRequest.TraceContext.Attributes, context); + _executingInvocations.TryAdd(invocationRequest.InvocationId, context); + + if (_cancelCapabilityEnabled) + { + context.CancellationToken.Register(() => SendInvocationCancel(invocationRequest.InvocationId)); } + + SendStreamingMessage(new StreamingMessage + { + 
InvocationRequest = invocationRequest + }); } catch (Exception invokeEx) { @@ -531,6 +540,21 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context) } } + internal void SendInvocationCancel(string invocationId) + { + _workerChannelLogger.LogDebug($"Sending invocation cancel request for InvocationId {invocationId}"); + + var invocationCancel = new InvocationCancel + { + InvocationId = invocationId + }; + + SendStreamingMessage(new StreamingMessage + { + InvocationCancel = invocationCancel + }); + } + // gets metadata from worker public Task> GetFunctionMetadata() { diff --git a/src/WebJobs.Script.Grpc/MessageExtensions/StatusResultExtensions.cs b/src/WebJobs.Script.Grpc/MessageExtensions/StatusResultExtensions.cs index b2ff6656c4..b555e6c192 100644 --- a/src/WebJobs.Script.Grpc/MessageExtensions/StatusResultExtensions.cs +++ b/src/WebJobs.Script.Grpc/MessageExtensions/StatusResultExtensions.cs @@ -30,22 +30,19 @@ public static bool IsFailure(this StatusResult statusResult, out Exception excep } /// - /// This method is only hit on the invocation code path. enableUserCodeExceptionCapability = feature flag, - /// exposed as a capability that is set by the worker. + /// This method is only hit on the invocation code path. + /// enableUserCodeExceptionCapability = feature flag exposed as a capability that is set by the worker. 
/// public static bool IsInvocationSuccess(this StatusResult status, TaskCompletionSource tcs, bool enableUserCodeExceptionCapability = false) { switch (status.Status) { case StatusResult.Types.Status.Failure: + case StatusResult.Types.Status.Cancelled: var rpcException = GetRpcException(status, enableUserCodeExceptionCapability); tcs.SetException(rpcException); return false; - case StatusResult.Types.Status.Cancelled: - tcs.SetCanceled(); - return false; - default: return true; } @@ -65,6 +62,7 @@ public static Workers.Rpc.RpcException GetRpcException(StatusResult statusResult { return new Workers.Rpc.RpcException(status, ex.Message, ex.StackTrace, ex.Type, ex.IsUserException); } + return new Workers.Rpc.RpcException(status, ex.Message, ex.StackTrace); } return new Workers.Rpc.RpcException(status, string.Empty, string.Empty); diff --git a/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs b/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs index 162e7b0ca3..cd8b364caf 100644 --- a/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs +++ b/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs @@ -31,7 +31,7 @@ public async Task OnTimeoutExceptionAsync(ExceptionDispatchInfo exceptionInfo, T if (timeoutException?.Task != null) { - // We may double the timeoutGracePeriod here by first waiting to see if the iniital + // We may double the timeoutGracePeriod here by first waiting to see if the initial // function task that started the exception has completed. 
Task completedTask = await Task.WhenAny(timeoutException.Task, Task.Delay(timeoutGracePeriod)); diff --git a/src/WebJobs.Script/Description/Workers/WorkerFunctionDescriptorProvider.cs b/src/WebJobs.Script/Description/Workers/WorkerFunctionDescriptorProvider.cs index ebe03f19fb..2a41757232 100644 --- a/src/WebJobs.Script/Description/Workers/WorkerFunctionDescriptorProvider.cs +++ b/src/WebJobs.Script/Description/Workers/WorkerFunctionDescriptorProvider.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Reflection; using System.Reflection.Emit; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; using Microsoft.Azure.WebJobs.Description; @@ -54,6 +55,9 @@ protected override async Task> GetFunctionParame { var parameters = await base.GetFunctionParametersAsync(functionInvoker, functionMetadata, triggerMetadata, methodAttributes, inputBindings, outputBindings); + // Add cancellation token + parameters.Add(new ParameterDescriptor(ScriptConstants.SystemCancellationTokenParameterName, typeof(CancellationToken))); + var bindings = inputBindings.Union(outputBindings); try diff --git a/src/WebJobs.Script/Description/Workers/WorkerFunctionInvoker.cs b/src/WebJobs.Script/Description/Workers/WorkerFunctionInvoker.cs index f52e5eb17f..fa012a3624 100644 --- a/src/WebJobs.Script/Description/Workers/WorkerFunctionInvoker.cs +++ b/src/WebJobs.Script/Description/Workers/WorkerFunctionInvoker.cs @@ -63,8 +63,10 @@ protected override async Task InvokeCore(object[] parameters, FunctionIn await DelayUntilFunctionDispatcherInitializedOrShutdown(); } + var triggerParameterIndex = 0; + var cancellationTokenParameterIndex = 4; var bindingData = context.Binder.BindingData; - object triggerValue = TransformInput(parameters[0], bindingData); + object triggerValue = TransformInput(parameters[triggerParameterIndex], bindingData); var triggerInput = (_bindingMetadata.Name, _bindingMetadata.DataType ?? 
DataType.String, triggerValue); IEnumerable<(string, DataType, object)> inputs = new[] { triggerInput }; if (_inputBindings.Count > 1) @@ -84,9 +86,7 @@ protected override async Task InvokeCore(object[] parameters, FunctionIn Traceparent = Activity.Current?.Id, Tracestate = Activity.Current?.TraceStateString, Attributes = Activity.Current?.Tags, - - // TODO: link up cancellation token to parameter descriptors - CancellationToken = CancellationToken.None, + CancellationToken = HandleCancellationTokenParameter(parameters[cancellationTokenParameterIndex]), Logger = context.Logger }; @@ -187,6 +187,16 @@ private object TransformInput(object input, Dictionary bindingDa return input; } + private CancellationToken HandleCancellationTokenParameter(object input) + { + if (input == null) + { + return CancellationToken.None; + } + + return (CancellationToken)input; + } + private void HandleReturnParameter(ScriptInvocationResult result) { result.Outputs[ScriptConstants.SystemReturnParameterBindingName] = result.Return; diff --git a/src/WebJobs.Script/Host/ScriptHost.cs b/src/WebJobs.Script/Host/ScriptHost.cs index 92c75b84ca..0fd66b8472 100644 --- a/src/WebJobs.Script/Host/ScriptHost.cs +++ b/src/WebJobs.Script/Host/ScriptHost.cs @@ -317,7 +317,7 @@ public async Task InitializeAsync(CancellationToken cancellationToken = default) { // Initialize worker function invocation dispatcher only for valid functions after creating function descriptors // Dispatcher not needed for codeless function. 
- // Disptacher needed for non-dotnet codeless functions + // Dispatcher needed for non-dotnet codeless functions var filteredFunctionMetadata = functionMetadataList.Where(m => !Utility.IsCodelessDotNetLanguageFunction(m)); await _functionDispatcher.InitializeAsync(Utility.GetValidFunctions(filteredFunctionMetadata, Functions), cancellationToken); } diff --git a/src/WebJobs.Script/ScriptConstants.cs b/src/WebJobs.Script/ScriptConstants.cs index c80cc29d1f..0f67baeafa 100644 --- a/src/WebJobs.Script/ScriptConstants.cs +++ b/src/WebJobs.Script/ScriptConstants.cs @@ -66,6 +66,7 @@ public static class ScriptConstants public const string SystemReturnParameterBindingName = "$return"; public const string SystemReturnParameterName = "_return"; public const string SystemLoggerParameterName = "_logger"; + public const string SystemCancellationTokenParameterName = "_cancellationToken"; public const string DebugSentinelFileName = "debug_sentinel"; public const string DiagnosticSentinelFileName = "diagnostic_sentinel"; diff --git a/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs b/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs index 5a522d2eaa..e890153dbd 100644 --- a/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs +++ b/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs @@ -52,6 +52,7 @@ public static class RpcWorkerConstants public const string EnableUserCodeException = "EnableUserCodeException"; public const string SupportsLoadResponseCollection = "SupportsLoadResponseCollection"; public const string HandlesWorkerTerminateMessage = "HandlesWorkerTerminateMessage"; + public const string HandlesInvocationCancelMessage = "HandlesInvocationCancelMessage"; // Host Capabilities public const string V2Compatable = "V2Compatable"; diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs index 6ddbe1343c..bd5552a61e 100644 --- 
a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs @@ -55,7 +55,7 @@ public class GrpcWorkerChannelTests : IDisposable private readonly IFunctionDataCache _functionDataCache; private readonly IOptions _workerConcurrencyOptions; private GrpcWorkerChannel _workerChannel; - private GrpcWorkerChannel _workerChannelwithMockEventManager; + private GrpcWorkerChannel _workerChannelWithMockEventManager; public GrpcWorkerChannelTests() { @@ -110,7 +110,7 @@ public GrpcWorkerChannelTests() _workerConcurrencyOptions); _eventManagerMock.Setup(proxy => proxy.Publish(It.IsAny())).Verifiable(); - _workerChannelwithMockEventManager = new GrpcWorkerChannel( + _workerChannelWithMockEventManager = new GrpcWorkerChannel( _workerId, _eventManagerMock.Object, _testWorkerConfig, @@ -361,6 +361,138 @@ public async Task SendInvocationRequest_InputsTransferredOverSharedMemory() Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg))); } + [Fact] + public async Task SendInvocationRequest_SignalCancellation_WithCapability_SendsInvocationCancelRequest() + { + var cancellationWaitTimeMs = 3000; + var invocationId = Guid.NewGuid(); + var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}"; + + IDictionary capabilities = new Dictionary() + { + { RpcWorkerConstants.HandlesInvocationCancelMessage, "1" } + }; + + var cts = new CancellationTokenSource(); + cts.CancelAfter(cancellationWaitTimeMs); + var token = cts.Token; + + var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); + _testFunctionRpcService.PublishStartStreamEvent(_workerId); + _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); + await initTask; + var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null, token); + await _workerChannel.SendInvocationRequest(scriptInvocationContext); + + while 
(!token.IsCancellationRequested) + { + await Task.Delay(1000); + if (token.IsCancellationRequested) + { + break; + } + } + + var traces = _logger.GetLogMessages(); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog))); + } + + [Fact] + public async Task SendInvocationRequest_SignalCancellation_WithoutCapability_NoAction() + { + var cancellationWaitTimeMs = 3000; + var invocationId = Guid.NewGuid(); + var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}"; + + var cts = new CancellationTokenSource(); + cts.CancelAfter(cancellationWaitTimeMs); + var token = cts.Token; + + var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); + _testFunctionRpcService.PublishStartStreamEvent(_workerId); + _testFunctionRpcService.PublishWorkerInitResponseEvent(); + await initTask; + var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null, token); + await _workerChannel.SendInvocationRequest(scriptInvocationContext); + + while (!token.IsCancellationRequested) + { + await Task.Delay(1000); + if (token.IsCancellationRequested) + { + break; + } + } + + var traces = _logger.GetLogMessages(); + Assert.False(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog))); + } + + [Fact] + public async Task SendInvocationRequest_CancellationAlreadyRequested_ResultSourceCancelled() + { + var cancellationWaitTimeMs = 3000; + var invocationId = Guid.NewGuid(); + var expectedCancellationLog = "Cancellation has been requested, cancelling invocation request"; + + IDictionary capabilities = new Dictionary() + { + { RpcWorkerConstants.HandlesInvocationCancelMessage, "1" } + }; + + var cts = new CancellationTokenSource(); + cts.CancelAfter(cancellationWaitTimeMs); + var token = cts.Token; + + var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); + _testFunctionRpcService.PublishStartStreamEvent(_workerId); + 
_testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); + await initTask; + + while (!token.IsCancellationRequested) + { + await Task.Delay(1000); + if (token.IsCancellationRequested) + { + break; + } + } + + var resultSource = new TaskCompletionSource(); + var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, resultSource, token); + await _workerChannel.SendInvocationRequest(scriptInvocationContext); + + var traces = _logger.GetLogMessages(); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog))); + Assert.Equal(TaskStatus.Canceled, resultSource.Task.Status); + } + + [Fact] + public async Task SendInvocationCancelRequest_PublishesOutboundEvent() + { + var invocationId = Guid.NewGuid(); + var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}"; + + IDictionary capabilities = new Dictionary() + { + { RpcWorkerConstants.HandlesInvocationCancelMessage, "1" } + }; + + var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); + _testFunctionRpcService.PublishStartStreamEvent(_workerId); + _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); + await initTask; + + var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null); + _workerChannel.SendInvocationCancel(invocationId.ToString()); + + var traces = _logger.GetLogMessages(); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog))); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg))); + // The outbound log should happen twice: once for worker init request and once for the invocation cancel request + Assert.Equal(traces.Where(m => m.FormattedMessage.Equals(_expectedLogMsg)).Count(), 2); + } + [Fact] public async Task Drain_Verify() { @@ -868,7 +1000,7 @@ public async Task GetLatencies_DoesNot_StartTimer_WhenDynamicConcurrencyDisabled public async Task 
SendInvocationRequest_ValidateTraceContext() { ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null); - await _workerChannelwithMockEventManager.SendInvocationRequest(scriptInvocationContext); + await _workerChannelWithMockEventManager.SendInvocationRequest(scriptInvocationContext); if (_testEnvironment.IsApplicationInsightsAgentEnabled()) { _eventManagerMock.Verify(proxy => proxy.Publish(It.Is( @@ -895,7 +1027,7 @@ public async Task SendInvocationRequest_ValidateTraceContext_SessionId() activity.AddBaggage(ScriptConstants.LiveLogsSessionAIKey, sessionId); activity.Start(); ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null); - await _workerChannelwithMockEventManager.SendInvocationRequest(scriptInvocationContext); + await _workerChannelWithMockEventManager.SendInvocationRequest(scriptInvocationContext); activity.Stop(); _eventManagerMock.Verify(p => p.Publish(It.Is(grpcEvent => ValidateInvocationRequest(grpcEvent, sessionId)))); } @@ -942,7 +1074,7 @@ private IEnumerable GetTestFunctionsList(string runtime) }; } - private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId, TaskCompletionSource resultSource) + private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId, TaskCompletionSource resultSource, CancellationToken? token = null) { return new ScriptInvocationContext() { @@ -956,7 +1088,8 @@ private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId }, BindingData = new Dictionary(), Inputs = new List<(string name, DataType type, object val)>(), - ResultSource = resultSource + ResultSource = resultSource, + CancellationToken = token == null ? CancellationToken.None : (CancellationToken)token }; }