Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send InvocationCancel request to out-of-proc workers #8556

Merged
merged 19 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@
<!-- Please add your release notes in the following format:
- My change description (#PR)
-->

- Host support for out-of-proc cancellation tokens ([#2153](https://github.com/Azure/azure-functions-host/issues/2152))

**Release sprint:** Sprint 125
[ [bugs](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Abug+is%3Aclosed) | [features](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Afeature+is%3Aclosed) ]
[ [bugs](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Abug+is%3Aclosed) | [features](https://github.com/Azure/azure-functions-host/issues?q=is%3Aissue+milestone%3A%22Functions+Sprint+125%22+label%3Afeature+is%3Aclosed) ]
53 changes: 39 additions & 14 deletions src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -500,36 +500,61 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context)
_workerChannelLogger.LogDebug($"Function {context.FunctionMetadata.Name} failed to load");
context.ResultSource.TrySetException(_functionLoadErrors[context.FunctionMetadata.GetFunctionId()]);
_executingInvocations.TryRemove(context.ExecutionContext.InvocationId.ToString(), out ScriptInvocationContext _);
return;
}
else if (_metadataRequestErrors.ContainsKey(context.FunctionMetadata.GetFunctionId()))
{
_workerChannelLogger.LogDebug($"Worker failed to load metadata for {context.FunctionMetadata.Name}");
context.ResultSource.TrySetException(_metadataRequestErrors[context.FunctionMetadata.GetFunctionId()]);
_executingInvocations.TryRemove(context.ExecutionContext.InvocationId.ToString(), out ScriptInvocationContext _);
return;
}
else
{
if (context.CancellationToken.IsCancellationRequested)
{
context.ResultSource.SetCanceled();
return;
}
var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager);
AddAdditionalTraceContext(invocationRequest.TraceContext.Attributes, context);
_executingInvocations.TryAdd(invocationRequest.InvocationId, context);

SendStreamingMessage(new StreamingMessage
{
InvocationRequest = invocationRequest
});
if (context.CancellationToken.IsCancellationRequested)
{
_workerChannelLogger.LogDebug("Cancellation has been requested, cancelling invocation request");
context.ResultSource.SetCanceled();
return;
}

var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager);
AddAdditionalTraceContext(invocationRequest.TraceContext.Attributes, context);
_executingInvocations.TryAdd(invocationRequest.InvocationId, context);

context.CancellationToken.Register(() => SendInvocationCancel(invocationRequest.InvocationId));
liliankasem marked this conversation as resolved.
Show resolved Hide resolved
liliankasem marked this conversation as resolved.
Show resolved Hide resolved

SendStreamingMessage(new StreamingMessage
{
InvocationRequest = invocationRequest
});
}
catch (Exception invokeEx)
{
context.ResultSource.TrySetException(invokeEx);
}
}

internal void SendInvocationCancel(string invocationId)
liliankasem marked this conversation as resolved.
Show resolved Hide resolved
{
bool capabilityEnabled = !string.IsNullOrEmpty(_workerCapabilities.GetCapabilityState(RpcWorkerConstants.HandlesInvocationCancelMessage));
if (!capabilityEnabled)
{
return;
}

_workerChannelLogger.LogDebug($"Sending invocation cancel request for InvocationId {invocationId}");

var invocationCancel = new InvocationCancel
{
InvocationId = invocationId
};

SendStreamingMessage(new StreamingMessage
{
InvocationCancel = invocationCancel
});
}

// gets metadata from worker
public Task<List<RawFunctionMetadata>> GetFunctionMetadata()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public async Task OnTimeoutExceptionAsync(ExceptionDispatchInfo exceptionInfo, T

if (timeoutException?.Task != null)
{
// We may double the timeoutGracePeriod here by first waiting to see if the iniital
// We may double the timeoutGracePeriod here by first waiting to see if the initial
// function task that started the exception has completed.
Task completedTask = await Task.WhenAny(timeoutException.Task, Task.Delay(timeoutGracePeriod));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Azure.WebJobs.Description;
Expand Down Expand Up @@ -54,6 +55,9 @@ protected override async Task<Collection<ParameterDescriptor>> GetFunctionParame
{
var parameters = await base.GetFunctionParametersAsync(functionInvoker, functionMetadata, triggerMetadata, methodAttributes, inputBindings, outputBindings);

// Add cancellation token
parameters.Add(new ParameterDescriptor(ScriptConstants.SystemCancellationTokenParameterName, typeof(CancellationToken)));
liliankasem marked this conversation as resolved.
Show resolved Hide resolved

var bindings = inputBindings.Union(outputBindings);

try
Expand Down
18 changes: 14 additions & 4 deletions src/WebJobs.Script/Description/Workers/WorkerFunctionInvoker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ protected override async Task<object> InvokeCore(object[] parameters, FunctionIn
await DelayUntilFunctionDispatcherInitializedOrShutdown();
}

var triggerParameterIndex = 0;
var cancellationTokenParameterIndex = 4;
var bindingData = context.Binder.BindingData;
object triggerValue = TransformInput(parameters[0], bindingData);
object triggerValue = TransformInput(parameters[triggerParameterIndex], bindingData);
var triggerInput = (_bindingMetadata.Name, _bindingMetadata.DataType ?? DataType.String, triggerValue);
IEnumerable<(string, DataType, object)> inputs = new[] { triggerInput };
if (_inputBindings.Count > 1)
Expand All @@ -84,9 +86,7 @@ protected override async Task<object> InvokeCore(object[] parameters, FunctionIn
Traceparent = Activity.Current?.Id,
Tracestate = Activity.Current?.TraceStateString,
Attributes = Activity.Current?.Tags,

// TODO: link up cancellation token to parameter descriptors
CancellationToken = CancellationToken.None,
CancellationToken = HandleCancellationTokenParameter(parameters[cancellationTokenParameterIndex]),
Logger = context.Logger
};

Expand Down Expand Up @@ -187,6 +187,16 @@ private object TransformInput(object input, Dictionary<string, object> bindingDa
return input;
}

private CancellationToken HandleCancellationTokenParameter(object input)
{
if (input == null)
{
return CancellationToken.None;
}

return (CancellationToken)input;
liliankasem marked this conversation as resolved.
Show resolved Hide resolved
}

private void HandleReturnParameter(ScriptInvocationResult result)
{
result.Outputs[ScriptConstants.SystemReturnParameterBindingName] = result.Return;
Expand Down
2 changes: 1 addition & 1 deletion src/WebJobs.Script/Host/ScriptHost.cs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ public async Task InitializeAsync(CancellationToken cancellationToken = default)
{
// Initialize worker function invocation dispatcher only for valid functions after creating function descriptors
// Dispatcher not needed for codeless function.
// Disptacher needed for non-dotnet codeless functions
// Dispatcher needed for non-dotnet codeless functions
var filteredFunctionMetadata = functionMetadataList.Where(m => !Utility.IsCodelessDotNetLanguageFunction(m));
await _functionDispatcher.InitializeAsync(Utility.GetValidFunctions(filteredFunctionMetadata, Functions), cancellationToken);
}
Expand Down
1 change: 1 addition & 0 deletions src/WebJobs.Script/ScriptConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public static class ScriptConstants
public const string SystemReturnParameterBindingName = "$return";
public const string SystemReturnParameterName = "_return";
public const string SystemLoggerParameterName = "_logger";
public const string SystemCancellationTokenParameterName = "_cancellationToken";

public const string DebugSentinelFileName = "debug_sentinel";
public const string DiagnosticSentinelFileName = "diagnostic_sentinel";
Expand Down
1 change: 1 addition & 0 deletions src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public static class RpcWorkerConstants
public const string AcceptsListOfFunctionLoadRequests = "AcceptsListOfFunctionLoadRequests";
public const string EnableUserCodeException = "EnableUserCodeException";
public const string SupportsLoadResponseCollection = "SupportsLoadResponseCollection";
public const string HandlesInvocationCancelMessage = "HandlesInvocationCancelMessage";

// Host Capabilities
public const string V2Compatable = "V2Compatable";
Expand Down
128 changes: 122 additions & 6 deletions test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class GrpcWorkerChannelTests : IDisposable
private readonly IFunctionDataCache _functionDataCache;
private readonly IOptions<WorkerConcurrencyOptions> _workerConcurrencyOptions;
private GrpcWorkerChannel _workerChannel;
private GrpcWorkerChannel _workerChannelwithMockEventManager;
private GrpcWorkerChannel _workerChannelWithMockEventManager;

public GrpcWorkerChannelTests()
{
Expand Down Expand Up @@ -110,7 +110,7 @@ public GrpcWorkerChannelTests()
_workerConcurrencyOptions);

_eventManagerMock.Setup(proxy => proxy.Publish(It.IsAny<OutboundGrpcEvent>())).Verifiable();
_workerChannelwithMockEventManager = new GrpcWorkerChannel(
_workerChannelWithMockEventManager = new GrpcWorkerChannel(
_workerId,
_eventManagerMock.Object,
_testWorkerConfig,
Expand Down Expand Up @@ -319,6 +319,121 @@ public async Task SendInvocationRequest_InputsTransferredOverSharedMemory()
Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg)));
}

[Fact]
public async Task SendInvocationRequest_SignalCancellation_SendsInvocationCancelRequest()
{
var cancellationWaitTimeMs = 3000;
var invocationId = Guid.NewGuid();
var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}";

IDictionary<string, string> capabilities = new Dictionary<string, string>()
{
{ RpcWorkerConstants.HandlesInvocationCancelMessage, "1" }
};

var cts = new CancellationTokenSource();
cts.CancelAfter(cancellationWaitTimeMs);
var token = cts.Token;

var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None);
_testFunctionRpcService.PublishStartStreamEvent(_workerId);
_testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities);
await initTask;
var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null, token);
await _workerChannel.SendInvocationRequest(scriptInvocationContext);

while (!token.IsCancellationRequested)
{
await Task.Delay(1000);
if (token.IsCancellationRequested)
{
break;
}
}

var traces = _logger.GetLogMessages();
Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog)));
}

[Fact]
public async Task SendInvocationRequest_CancellationAlreadyRequested_ResultSourceCancelled()
{
var cancellationWaitTimeMs = 3000;
var invocationId = Guid.NewGuid();
var expectedCancellationLog = "Cancellation has been requested, cancelling invocation request";

IDictionary<string, string> capabilities = new Dictionary<string, string>()
{
{ RpcWorkerConstants.HandlesInvocationCancelMessage, "1" }
};

var cts = new CancellationTokenSource();
cts.CancelAfter(cancellationWaitTimeMs);
var token = cts.Token;

var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None);
_testFunctionRpcService.PublishStartStreamEvent(_workerId);
_testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities);
await initTask;

while (!token.IsCancellationRequested)
{
await Task.Delay(1000);
if (token.IsCancellationRequested)
{
break;
}
}

var resultSource = new TaskCompletionSource<ScriptInvocationResult>();
var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, resultSource, token);
await _workerChannel.SendInvocationRequest(scriptInvocationContext);

var traces = _logger.GetLogMessages();
Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog)));
Assert.Equal(TaskStatus.Canceled, resultSource.Task.Status);
}

[Fact]
public async Task SendInvocationCancelRequest__WithWorkerCapability_PublishesOutboundEvent()
{
var invocationId = Guid.NewGuid();
var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}";

IDictionary<string, string> capabilities = new Dictionary<string, string>()
{
{ RpcWorkerConstants.HandlesInvocationCancelMessage, "1" }
};

var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None);
_testFunctionRpcService.PublishStartStreamEvent(_workerId);
_testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities);
await initTask;

var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null);
_workerChannel.SendInvocationCancel(invocationId.ToString());

var traces = _logger.GetLogMessages();
Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog)));
Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg)));
// The outbound log should happen twice: once for worker init request and once for the invocation cancel request
Assert.Equal(traces.Where(m => m.FormattedMessage.Equals(_expectedLogMsg)).Count(), 2);
}

[Fact]
public void SendInvocationCancelRequest_WithoutWorkerCapability_NoAction()
{
var invocationId = Guid.NewGuid();
var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}";

var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null);
_workerChannel.SendInvocationCancel(invocationId.ToString());

var traces = _logger.GetLogMessages();
Assert.False(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog)));
Assert.Equal(traces.Where(m => m.FormattedMessage.Equals(_expectedLogMsg)).Count(), 0);
}

[Fact]
public async Task Drain_Verify()
{
Expand Down Expand Up @@ -826,7 +941,7 @@ public async Task GetLatencies_DoesNot_StartTimer_WhenDynamicConcurrencyDisabled
public async Task SendInvocationRequest_ValidateTraceContext()
{
ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null);
await _workerChannelwithMockEventManager.SendInvocationRequest(scriptInvocationContext);
await _workerChannelWithMockEventManager.SendInvocationRequest(scriptInvocationContext);
if (_testEnvironment.IsApplicationInsightsAgentEnabled())
{
_eventManagerMock.Verify(proxy => proxy.Publish(It.Is<OutboundGrpcEvent>(
Expand All @@ -853,7 +968,7 @@ public async Task SendInvocationRequest_ValidateTraceContext_SessionId()
activity.AddBaggage(ScriptConstants.LiveLogsSessionAIKey, sessionId);
activity.Start();
ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null);
await _workerChannelwithMockEventManager.SendInvocationRequest(scriptInvocationContext);
await _workerChannelWithMockEventManager.SendInvocationRequest(scriptInvocationContext);
activity.Stop();
_eventManagerMock.Verify(p => p.Publish(It.Is<OutboundGrpcEvent>(grpcEvent => ValidateInvocationRequest(grpcEvent, sessionId))));
}
Expand Down Expand Up @@ -900,7 +1015,7 @@ private IEnumerable<FunctionMetadata> GetTestFunctionsList(string runtime)
};
}

private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId, TaskCompletionSource<ScriptInvocationResult> resultSource)
private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId, TaskCompletionSource<ScriptInvocationResult> resultSource, CancellationToken? token = null)
{
return new ScriptInvocationContext()
{
Expand All @@ -914,7 +1029,8 @@ private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId
},
BindingData = new Dictionary<string, object>(),
Inputs = new List<(string name, DataType type, object val)>(),
ResultSource = resultSource
ResultSource = resultSource,
CancellationToken = token == null ? CancellationToken.None : (CancellationToken)token
};
}

Expand Down