diff --git a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs index 3ee6ed52f5..c0c468226f 100644 --- a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs +++ b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs @@ -32,7 +32,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Yarp.ReverseProxy.Forwarder; - using static Microsoft.Azure.WebJobs.Script.Grpc.Messages.RpcLog.Types; using FunctionMetadata = Microsoft.Azure.WebJobs.Script.Description.FunctionMetadata; using MsgType = Microsoft.Azure.WebJobs.Script.Grpc.Messages.StreamingMessage.ContentOneofCase; @@ -51,6 +50,7 @@ internal partial class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable private readonly List _workerStatusLatencyHistory = new List(); private readonly IOptions _workerConcurrencyOptions; private readonly WaitCallback _processInbound; + private readonly IInvocationMessageDispatcherFactory _messageDispatcherFactory; private readonly object _syncLock = new object(); private readonly object _metadataLock = new object(); private readonly Dictionary> _pendingActions = new(); @@ -65,7 +65,7 @@ internal partial class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable private RpcWorkerChannelState _state; private IDictionary _functionLoadErrors = new Dictionary(); private IDictionary _metadataRequestErrors = new Dictionary(); - private ConcurrentDictionary _executingInvocations = new ConcurrentDictionary(); + private ConcurrentDictionary _executingInvocations = new(); private IDictionary> _functionInputBuffers = new ConcurrentDictionary>(); private ConcurrentDictionary> _workerStatusRequests = new ConcurrentDictionary>(); private List _inputLinks = new List(); @@ -137,6 +137,9 @@ internal GrpcWorkerChannel( _startLatencyMetric = metricsLogger?.LatencyEvent(string.Format(MetricEventNames.WorkerInitializeLatency, workerConfig.Description.Language, attemptCount)); _state = RpcWorkerChannelState.Default; + + // 
Temporary switch to allow fully testing new algorithm in production + _messageDispatcherFactory = GetProcessorFactory(); } private bool IsHttpProxyingWorker => _httpProxyEndpoint is not null; @@ -149,6 +152,23 @@ internal GrpcWorkerChannel( public RpcWorkerConfig WorkerConfig => _workerConfig; + // Temporary switch that allows us to move between the "old" ThreadPool-only processor + // and a "new" Channel processor (for proper ordering of messages). + private IInvocationMessageDispatcherFactory GetProcessorFactory() + { + if (_hostingConfigOptions.Value.EnableOrderedInvocationMessages || + FeatureFlags.IsEnabled(ScriptConstants.FeatureFlagEnableOrderedInvocationmessages, _environment)) + { + _workerChannelLogger.LogDebug($"Using {nameof(OrderedInvocationMessageDispatcherFactory)}."); + return new OrderedInvocationMessageDispatcherFactory(ProcessItem, _workerChannelLogger); + } + else + { + _workerChannelLogger.LogDebug($"Using {nameof(ThreadPoolInvocationProcessorFactory)}."); + return new ThreadPoolInvocationProcessorFactory(_processInbound); + } + } + private void ProcessItem(InboundGrpcEvent msg) { // note this method is a thread-pool (QueueUserWorkItem) entry-point @@ -251,7 +271,8 @@ private async Task ProcessInbound() { Logger.ChannelReceivedMessage(_workerChannelLogger, msg.WorkerId, msg.MessageType); } - ThreadPool.QueueUserWorkItem(_processInbound, msg); + + DispatchMessage(msg); } } } @@ -266,6 +287,40 @@ private async Task ProcessInbound() } } + private void DispatchMessage(InboundGrpcEvent msg) + { + // RpcLog and InvocationResponse messages are special. 
They need to be handled by the InvocationMessageDispatcher + switch (msg.MessageType) + { + case MsgType.RpcLog when msg.Message.RpcLog.LogCategory == RpcLogCategory.User || msg.Message.RpcLog.LogCategory == RpcLogCategory.CustomMetric: + if (_executingInvocations.TryGetValue(msg.Message.RpcLog.InvocationId, out var invocation)) + { + invocation.Dispatcher.DispatchRpcLog(msg); + } + else + { + // We received a log outside of a invocation + ThreadPool.QueueUserWorkItem(_processInbound, msg); + } + break; + case MsgType.InvocationResponse: + if (_executingInvocations.TryGetValue(msg.Message.InvocationResponse.InvocationId, out invocation)) + { + invocation.Dispatcher.DispatchInvocationResponse(msg); + } + else + { + // This should never happen, but if it does, just send it to the ThreadPool. + ThreadPool.QueueUserWorkItem(_processInbound, msg); + } + break; + default: + // All other messages can go to the thread pool. + ThreadPool.QueueUserWorkItem(_processInbound, msg); + break; + } + } + public bool IsChannelReadyForInvocations() { return !_disposing && !_disposed && _state.HasFlag(RpcWorkerChannelState.InvocationBuffersInitialized | RpcWorkerChannelState.Initialized); @@ -750,14 +805,14 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context) { _workerChannelLogger.LogDebug("Function {functionName} failed to load", context.FunctionMetadata.Name); context.ResultSource.TrySetException(_functionLoadErrors[context.FunctionMetadata.GetFunctionId()]); - _executingInvocations.TryRemove(invocationId, out ScriptInvocationContext _); + RemoveExecutingInvocation(invocationId); return; } else if (_metadataRequestErrors.ContainsKey(context.FunctionMetadata.GetFunctionId())) { _workerChannelLogger.LogDebug("Worker failed to load metadata for {functionName}", context.FunctionMetadata.Name); context.ResultSource.TrySetException(_metadataRequestErrors[context.FunctionMetadata.GetFunctionId()]); - _executingInvocations.TryRemove(invocationId, out 
ScriptInvocationContext _); + RemoveExecutingInvocation(invocationId); return; } @@ -771,7 +826,7 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context) var invocationRequest = await context.ToRpcInvocationRequest(_workerChannelLogger, _workerCapabilities, _isSharedMemoryDataTransferEnabled, _sharedMemoryManager); AddAdditionalTraceContext(invocationRequest.TraceContext.Attributes, context); - _executingInvocations.TryAdd(invocationRequest.InvocationId, context); + _executingInvocations.TryAdd(invocationRequest.InvocationId, new(context, _messageDispatcherFactory.Create(invocationRequest.InvocationId))); _metricsLogger.LogEvent(string.Format(MetricEventNames.WorkerInvoked, Id), functionName: Sanitizer.Sanitize(context.FunctionMetadata.Name)); await SendStreamingMessageAsync(new StreamingMessage @@ -980,8 +1035,9 @@ internal async Task InvokeResponse(InvocationResponse invokeResponse) // Check if the worker supports logging user-code-thrown exceptions to app insights bool capabilityEnabled = !string.IsNullOrEmpty(_workerCapabilities.GetCapabilityState(RpcWorkerConstants.EnableUserCodeException)); - if (_executingInvocations.TryRemove(invokeResponse.InvocationId, out ScriptInvocationContext context)) + if (_executingInvocations.TryRemove(invokeResponse.InvocationId, out var invocation)) { + var context = invocation.Context; if (invokeResponse.Result.IsInvocationSuccess(context.ResultSource, capabilityEnabled)) { _metricsLogger.LogEvent(string.Format(MetricEventNames.WorkerInvokeSucceeded, Id)); @@ -1051,6 +1107,8 @@ internal async Task InvokeResponse(InvocationResponse invokeResponse) SendCloseSharedMemoryResourcesForInvocationRequest(outputMaps); } } + + invocation.Dispose(); } else { @@ -1083,8 +1141,10 @@ internal void SendCloseSharedMemoryResourcesForInvocationRequest(IList o internal void Log(GrpcEvent msg) { var rpcLog = msg.Message.RpcLog; - if (_executingInvocations.TryGetValue(rpcLog.InvocationId, out ScriptInvocationContext context)) 
+ if (_executingInvocations.TryGetValue(rpcLog.InvocationId, out var invocation)) { + var context = invocation.Context; + // Restore the execution context from the original invocation. This allows AsyncLocal state to flow to loggers. System.Threading.ExecutionContext.Run(context.AsyncExecutionContext, static (state) => { @@ -1273,12 +1333,25 @@ internal void ReceiveWorkerStatusResponse(string requestId, WorkerStatusResponse } } + private void RemoveExecutingInvocation(string invocationId) + { + if (_executingInvocations.TryRemove(invocationId, out var invocation)) + { + invocation.Dispose(); + } + } + protected virtual void Dispose(bool disposing) { if (!_disposed) { if (disposing) { + foreach (var id in _executingInvocations.Keys) + { + RemoveExecutingInvocation(id); + } + _startLatencyMetric?.Dispose(); _workerInitTask?.TrySetCanceled(); _timer?.Dispose(); @@ -1366,9 +1439,9 @@ private void StopWorkerProcess() public async Task DrainInvocationsAsync() { _workerChannelLogger.LogDebug("Count of in-buffer invocations waiting to be drained out: {invocationCount}", _executingInvocations.Count); - foreach (ScriptInvocationContext currContext in _executingInvocations.Values) + foreach (var invocation in _executingInvocations.Values) { - await currContext.ResultSource.Task; + await invocation.Context.ResultSource.Task; } } @@ -1384,12 +1457,12 @@ public bool TryFailExecutions(Exception workerException) return false; } - foreach (ScriptInvocationContext currContext in _executingInvocations?.Values) + foreach (var invocation in _executingInvocations?.Values) { - string invocationId = currContext?.ExecutionContext?.InvocationId.ToString(); + string invocationId = invocation.Context?.ExecutionContext?.InvocationId.ToString(); _workerChannelLogger.LogDebug("Worker '{workerId}' encountered a fatal error. 
Failing invocation: '{invocationId}'", _workerId, invocationId); - currContext?.ResultSource?.TrySetException(workerException); - _executingInvocations.TryRemove(invocationId, out ScriptInvocationContext _); + invocation.Context?.ResultSource?.TrySetException(workerException); + RemoveExecutingInvocation(invocationId); } return true; } @@ -1529,6 +1602,24 @@ private void AddAdditionalTraceContext(MapField attributes, Scri } } + private sealed class ExecutingInvocation : IDisposable + { + public ExecutingInvocation(ScriptInvocationContext context, IInvocationMessageDispatcher dispatcher) + { + Context = context; + Dispatcher = dispatcher; + } + + public ScriptInvocationContext Context { get; } + + public IInvocationMessageDispatcher Dispatcher { get; } + + public void Dispose() + { + (Dispatcher as IDisposable)?.Dispose(); + } + } + private sealed class PendingItem { private readonly Action _callback; diff --git a/src/WebJobs.Script.Grpc/Channel/IInvocationMessageDispatcher.cs b/src/WebJobs.Script.Grpc/Channel/IInvocationMessageDispatcher.cs new file mode 100644 index 0000000000..14b01fca3c --- /dev/null +++ b/src/WebJobs.Script.Grpc/Channel/IInvocationMessageDispatcher.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; + +namespace Microsoft.Azure.WebJobs.Script.Grpc; + +/// +/// Interface for processing grpc messages that may come from the worker that are +/// related to an invocation (RpcLog and InvocationResponse). The contract here is as-follows: +/// - The contract with a worker is that, during an invocation (i.e an InvocationRequest has been sent), there are only +/// two messages that the worker can send us related to that invocation +/// - one or many RpcLog messages +/// - one final InvocationResponse that effectively ends the invocation. 
Once the InvocationResponse is received, no +/// more RpcLog messages from this specific invocation will be processed. +/// - The GrpcChannel is looping to dequeue grpc messages from a worker. When it finds one that is either an RpcLog +/// or an InvocationResponse, it will call the matching method on this interface (i.e. RpcLog -> DispatchRpcLog) +/// from the same thread that is looping and dequeuing items from the grpc channel. The implementors of this interface +/// must quickly dispatch the message to a background Task or Thread for handling, so as to not block the +/// main loop from dequeuing more messages. +/// - Because the methods on this interface are all called from the same thread, they do not need to be +/// thread-safe. They can assume that they will not be called multiple times from different threads. +/// +internal interface IInvocationMessageDispatcher +{ + /// + /// Inspects the incoming RpcLog and dispatches to a Thread or background Task as quickly as possible. This method is + /// called from a loop processing incoming grpc messages and any thread blocking will delay the processing of that loop. + /// It can be assumed that this method will never be called from multiple threads simultaneously and thus does not need + /// to be thread-safe. + /// + /// The RpcLog message. Implementors can assume that this message is an RpcLog. + void DispatchRpcLog(InboundGrpcEvent msg); + + /// + /// Inspects an incoming InvocationResponse message and dispatches to a Thread or background Task as quickly as possible. + /// This method is called from a loop processing incoming grpc messages and any thread blocking will delay the processing + /// of that loop. It can be assumed that this method will never be called from multiple threads simultaneously and thus + /// does not need to be thread-safe. + /// + /// The InvocationResponse message. Implementors can assume that this message is an InvocationResponse. 
+ void DispatchInvocationResponse(InboundGrpcEvent msg); +} \ No newline at end of file diff --git a/src/WebJobs.Script.Grpc/Channel/IInvocationMessageDispatcherFactory.cs b/src/WebJobs.Script.Grpc/Channel/IInvocationMessageDispatcherFactory.cs new file mode 100644 index 0000000000..dcc1d52cc2 --- /dev/null +++ b/src/WebJobs.Script.Grpc/Channel/IInvocationMessageDispatcherFactory.cs @@ -0,0 +1,9 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +namespace Microsoft.Azure.WebJobs.Script.Grpc; + +internal interface IInvocationMessageDispatcherFactory +{ + IInvocationMessageDispatcher Create(string invocationId); +} \ No newline at end of file diff --git a/src/WebJobs.Script.Grpc/Channel/OrderedInvocationMessageDispatcher.cs b/src/WebJobs.Script.Grpc/Channel/OrderedInvocationMessageDispatcher.cs new file mode 100644 index 0000000000..ec46cb70f3 --- /dev/null +++ b/src/WebJobs.Script.Grpc/Channel/OrderedInvocationMessageDispatcher.cs @@ -0,0 +1,160 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.WebJobs.Script.Grpc; + +/// +/// An implementation of that internally uses a per-invocation to ensure +/// ordering of messages is maintained. The calls to and +/// both write messages to the same . Then, a background reads messages from this channel and invokes the +/// provided in the constructor. +/// +/// This dispatcher is created with the to ensure that instances are created and disposed +/// per-invocation. This means that every instance is only ever responsible for processing messages from a single invocation. 
+/// +internal sealed class OrderedInvocationMessageDispatcher : IInvocationMessageDispatcher, IDisposable +{ + private readonly ILogger _logger; + private readonly string _invocationId; + private readonly Action _processItemWithChannel; + + // Separated these out for easier testing of the flow. + private readonly Action _processItemWithThreadPool; + + private Channel _channel; + private bool _isChannelInitialized = false; + private bool _invocationComplete = false; + private bool _disposed = false; + + /// + /// Initializes a new instance of the class. + /// + /// The function invocation id. + /// The logger. + /// A callback to be invoked when processing an item. + public OrderedInvocationMessageDispatcher(string invocationId, ILogger logger, Action processItem) + : this(invocationId, logger, processItem, processItem) + { + } + + /// + /// Initializes a new instance of the class. This constructor is only for testing purposes. + /// It allows a different action to be called when using the fallback ThreadPool dispatch behavior. This allows the tests to validate that + /// the fallback is correctly invoked. + /// + /// The function invocation id. + /// The logger. + /// A callback to be invoked when processing an item from the internal Channel. + /// A callback to be invoked when processing an item on the ThreadPool. This is a fallback scenario. 
+ internal OrderedInvocationMessageDispatcher(string invocationId, ILogger logger, Action processItemWithChannel, + Action processItemWithThreadPool) + { + _logger = logger; + _invocationId = invocationId; + _processItemWithChannel = processItemWithChannel; + _processItemWithThreadPool = processItemWithThreadPool; + } + + // For testing + internal Channel MessageChannel => _channel; + + private static Channel InitializeChannel() => + Channel.CreateUnbounded( + new UnboundedChannelOptions + { + SingleReader = true, + SingleWriter = true + }); + + public void DispatchRpcLog(InboundGrpcEvent msg) + { + // This is not thread-safe, but it is always called in-order from the same thread, so no + // need for locking. + + // The channel is only needed if we receive any RpcLog messages before we receive an + // InvocationResponse message. If the only message we ever see is InvocationResponse, we know + // the invocation is already complete and there's no need to use a channel for ordering, so skip + // the initialization altogether. In DispatchInvocationResponse, we'll fall back to the ThreadPool + // if the channel is not initialized. + + // We also do not want to initialize the channel if we're receiving an RpcLog after the invocation + // has completed. In this case, the RpcLog will be dropped anyway, so there's no need to maintain + // ordering with the channel. + if (!_isChannelInitialized && !_invocationComplete) + { + _channel = InitializeChannel(); + _ = ReadMessagesAsync(); + _isChannelInitialized = true; + } + + WriteToChannel(msg); + } + + public void DispatchInvocationResponse(InboundGrpcEvent msg) + { + // Any other messages that come here shouldn't use the Channel. + _invocationComplete = true; + + if (_isChannelInitialized) + { + WriteToChannel(msg); + + // Receiving an InvocationResponse signals that we're done with this invocation. + _channel.Writer.TryComplete(); + } + else + { + // Channel was never started. We must not have needed it. 
Send directly to ThreadPool. + DispatchToThreadPool(msg); + } + } + + private void WriteToChannel(InboundGrpcEvent msg) + { + if (_channel is null || !_channel.Writer.TryWrite(msg)) + { + // If this fails, fall back to the ThreadPool + _logger.LogDebug("Cannot write '{msgType}' to channel for InvocationId '{functionInvocationId}'. Dispatching message to the ThreadPool.", msg.MessageType, _invocationId); + DispatchToThreadPool(msg); + } + } + + private void DispatchToThreadPool(InboundGrpcEvent msg) => + ThreadPool.QueueUserWorkItem(state => _processItemWithThreadPool((InboundGrpcEvent)state), msg); + + private async Task ReadMessagesAsync() + { + try + { + await foreach (InboundGrpcEvent msg in _channel.Reader.ReadAllAsync()) + { + // Assume the Action being called is already wrapped in a try/catch + _processItemWithChannel(msg); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while reading InvocationProcessor channel for InvocationId '{functionInvocationId}'.", _invocationId); + + // Ensure nothing else will be written to the channel. There is a possibility + // that some messages are lost here. + _channel.Writer.TryComplete(ex); + } + } + + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + _channel?.Writer.TryComplete(); + } + } +} \ No newline at end of file diff --git a/src/WebJobs.Script.Grpc/Channel/OrderedInvocationMessageDispatcherFactory.cs b/src/WebJobs.Script.Grpc/Channel/OrderedInvocationMessageDispatcherFactory.cs new file mode 100644 index 0000000000..880dcff08e --- /dev/null +++ b/src/WebJobs.Script.Grpc/Channel/OrderedInvocationMessageDispatcherFactory.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +using System; +using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.WebJobs.Script.Grpc; + +internal class OrderedInvocationMessageDispatcherFactory : IInvocationMessageDispatcherFactory +{ + private readonly Action _processItem; + private readonly ILogger _logger; + + public OrderedInvocationMessageDispatcherFactory(Action processItem, ILogger logger) + { + _processItem = processItem; + _logger = logger; + } + + public IInvocationMessageDispatcher Create(string invocationId) => + new OrderedInvocationMessageDispatcher(invocationId, _logger, _processItem); +} \ No newline at end of file diff --git a/src/WebJobs.Script.Grpc/Channel/ThreadPoolInvocationProcessorFactory.cs b/src/WebJobs.Script.Grpc/Channel/ThreadPoolInvocationProcessorFactory.cs new file mode 100644 index 0000000000..4ceabad16a --- /dev/null +++ b/src/WebJobs.Script.Grpc/Channel/ThreadPoolInvocationProcessorFactory.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System.Threading; +using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; + +namespace Microsoft.Azure.WebJobs.Script.Grpc; + +/// +/// Temporary class that handles both creation of the processor and acts as the processor itself. Once we've +/// confirmed that the OrderedInvocationMessageDispatcher works as expected, we will remove this. 
+/// +internal class ThreadPoolInvocationProcessorFactory : IInvocationMessageDispatcher, IInvocationMessageDispatcherFactory +{ + private readonly WaitCallback _callback; + + public ThreadPoolInvocationProcessorFactory(WaitCallback callback) + { + _callback = callback; + } + + // always return a single instance + public IInvocationMessageDispatcher Create(string invocationId) => this; + + public void DispatchRpcLog(InboundGrpcEvent msg) => ThreadPool.QueueUserWorkItem(_callback, msg); + + public void DispatchInvocationResponse(InboundGrpcEvent msg) => ThreadPool.QueueUserWorkItem(_callback, msg); +} \ No newline at end of file diff --git a/src/WebJobs.Script/Config/FunctionsHostingConfigOptions.cs b/src/WebJobs.Script/Config/FunctionsHostingConfigOptions.cs index 53a49e75e9..20996ac00b 100644 --- a/src/WebJobs.Script/Config/FunctionsHostingConfigOptions.cs +++ b/src/WebJobs.Script/Config/FunctionsHostingConfigOptions.cs @@ -113,6 +113,19 @@ public bool DisableLinuxAppServiceExecutionDetails } } + public bool EnableOrderedInvocationMessages + { + get + { + return GetFeature(ScriptConstants.FeatureFlagEnableOrderedInvocationmessages) == "1"; + } + + set + { + _features[ScriptConstants.FeatureFlagEnableOrderedInvocationmessages] = value ? 
"1" : "0"; + } + } + /// /// Gets the highest version of extension bundle v3 supported /// diff --git a/src/WebJobs.Script/ScriptConstants.cs b/src/WebJobs.Script/ScriptConstants.cs index 177d1ae155..0df6017078 100644 --- a/src/WebJobs.Script/ScriptConstants.cs +++ b/src/WebJobs.Script/ScriptConstants.cs @@ -128,6 +128,7 @@ public static class ScriptConstants public const string FeatureFlagEnableWorkerIndexing = "EnableWorkerIndexing"; public const string FeatureFlagEnableDebugTracing = "EnableDebugTracing"; public const string FeatureFlagEnableProxies = "EnableProxies"; + public const string FeatureFlagEnableOrderedInvocationmessages = "EnableOrderedInvocationMessages"; public const string HostingConfigDisableLinuxAppServiceDetailedExecutionEvents = "DisableLinuxExecutionDetails"; public const string HostingConfigDisableLinuxAppServiceExecutionEventLogBackoff = "DisableLinuxLogBackoff"; public const string FeatureFlagEnableLegacyDurableVersionCheck = "EnableLegacyDurableVersionCheck"; diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs index 46e41f33e2..5457fbc9fd 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs @@ -1355,7 +1355,47 @@ public async Task GetFunctionMetadata_MultipleCalls_ReturnSameTask() Assert.Same(functionsTask1, functionsTask2); } - private IEnumerable GetTestFunctionsList(string runtime, bool addWorkerProperties = false) + [Fact] + public async Task Log_And_InvocationResult_OrderedCorrectly() + { + // Without this feature flag, this test fails every time on multi-core machines as the logs will + // be processed out-of-order + _testEnvironment.SetEnvironmentVariable(EnvironmentSettingNames.AzureWebJobsFeatureFlags, ScriptConstants.FeatureFlagEnableOrderedInvocationmessages); + + await CreateDefaultWorkerChannel(); + _metricsLogger.ClearCollections(); + + 
_logger.ClearLogMessages(); + + var invocationId = Guid.NewGuid(); + ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(invocationId, new TaskCompletionSource(), logger: _logger); + await _workerChannel.SendInvocationRequest(scriptInvocationContext); + + int logLoop = 10; + for (int j = 0; j < logLoop; j++) + { + _testFunctionRpcService.PublishLogEvent($"{invocationId} {j}", invocationId.ToString()); + } + + _testFunctionRpcService.PublishInvocationResponseEvent(invocationId.ToString()); + + LogMessage[] GetInvocationLogs() + { + return _logger.GetLogMessages().Where(m => m.FormattedMessage.StartsWith(invocationId.ToString())).ToArray(); + } + + await TestHelpers.Await(() => GetInvocationLogs().Length == logLoop, + timeout: 3000, userMessageCallback: () => $"Expected {logLoop} logs. Received {GetInvocationLogs().Length}"); + + // ensure they came in the correct order + var logs = GetInvocationLogs(); + for (int i = 0; i < logLoop; i++) + { + Assert.EndsWith(i.ToString(), logs[i].FormattedMessage); + } + } + + private static IEnumerable GetTestFunctionsList(string runtime, bool addWorkerProperties = false) { var metadata1 = new FunctionMetadata() { @@ -1394,7 +1434,8 @@ private IEnumerable GetTestFunctionsList(string runtime, bool }; } - private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId, TaskCompletionSource resultSource, CancellationToken? token = null) + public static ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId, TaskCompletionSource resultSource, + CancellationToken? 
token = null, ILogger logger = null, string scriptRootPath = null) { return new ScriptInvocationContext() { @@ -1403,13 +1444,15 @@ private ScriptInvocationContext GetTestScriptInvocationContext(Guid invocationId { InvocationId = invocationId, FunctionName = "js1", - FunctionAppDirectory = _scriptRootPath, - FunctionDirectory = _scriptRootPath + FunctionAppDirectory = scriptRootPath, + FunctionDirectory = scriptRootPath }, BindingData = new Dictionary(), Inputs = new List<(string Name, DataType Type, object Val)>(), ResultSource = resultSource, - CancellationToken = token == null ? CancellationToken.None : (CancellationToken)token + CancellationToken = token == null ? CancellationToken.None : (CancellationToken)token, + AsyncExecutionContext = System.Threading.ExecutionContext.Capture(), + Logger = logger }; } diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/OrderedInvocationDispatcherTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/OrderedInvocationDispatcherTests.cs new file mode 100644 index 0000000000..b712295bcc --- /dev/null +++ b/test/WebJobs.Script.Tests/Workers/Rpc/OrderedInvocationDispatcherTests.cs @@ -0,0 +1,147 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +using System; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Script.Grpc; +using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; +using Microsoft.Azure.WebJobs.Script.Grpc.Messages; +using Xunit; + +namespace Microsoft.Azure.WebJobs.Script.Tests.Workers.Rpc +{ + public class OrderedInvocationDispatcherTests : IDisposable + { + private readonly TestLogger _logger = new TestLogger(); + private int _channelCount; + private int _threadCount; + private OrderedInvocationMessageDispatcher _dispatcher; + + public OrderedInvocationDispatcherTests() + { + _dispatcher = new OrderedInvocationMessageDispatcher(Guid.NewGuid().ToString(), _logger, ProcessWithChannel, ProcessWithThreadPool); + } + + [Fact] + public async Task Processor_WithLogs_UsesChannel() + { + _dispatcher.DispatchRpcLog(CreateRpcLog()); + _dispatcher.DispatchRpcLog(CreateRpcLog()); + _dispatcher.DispatchRpcLog(CreateRpcLog()); + _dispatcher.DispatchInvocationResponse(CreateInvocationResponse()); + + await TestHelpers.Await(() => _channelCount == 4); + + Assert.Empty(_logger.GetLogMessages()); + Assert.Equal(0, _threadCount); + Assert.Equal(4, _channelCount); + Assert.NotNull(_dispatcher.MessageChannel); + Assert.True(_dispatcher.MessageChannel.Reader.Completion.IsCompletedSuccessfully); + } + + [Fact] + public async Task Processor_OnDispose_ClosesWriterAndSendsToThreadPool() + { + _dispatcher.DispatchRpcLog(CreateRpcLog()); + + await TestHelpers.Await(() => _channelCount == 1); + + _dispatcher.Dispose(); + + _dispatcher.DispatchInvocationResponse(CreateInvocationResponse()); + + // We still expect this to be run, but on the ThreadPool rather than via the channel. 
+ await TestHelpers.Await(() => _threadCount == 1); + + Assert.Collection(_logger.GetLogMessages(), m => Assert.StartsWith("Cannot write", m.FormattedMessage)); + Assert.Equal(1, _channelCount); + Assert.Equal(1, _threadCount); + Assert.NotNull(_dispatcher.MessageChannel); + Assert.True(_dispatcher.MessageChannel.Reader.Completion.IsCompletedSuccessfully); + } + + [Fact] + public async Task Processor_NoLogs_DoesNotUseChannel() + { + _dispatcher.DispatchInvocationResponse(CreateInvocationResponse()); + + await TestHelpers.Await(() => _threadCount == 1); + + Assert.Equal(0, _channelCount); + Assert.Equal(1, _threadCount); + Assert.Null(_dispatcher.MessageChannel); + } + + [Fact] + public async Task Processor_RpcLogAfterChannelCloses_UsesThreadPool() + { + // Use an RpcLog to initialize the channel, then close it, then try to log again. + _dispatcher.DispatchRpcLog(CreateRpcLog()); + _dispatcher.DispatchInvocationResponse(CreateInvocationResponse()); + _dispatcher.DispatchRpcLog(CreateRpcLog()); + + await TestHelpers.Await(() => _channelCount == 2); + await TestHelpers.Await(() => _threadCount == 1); + + Assert.Equal(2, _channelCount); + Assert.Equal(1, _threadCount); + Assert.Collection(_logger.GetLogMessages(), m => Assert.StartsWith("Cannot write", m.FormattedMessage)); + Assert.NotNull(_dispatcher.MessageChannel); + Assert.True(_dispatcher.MessageChannel.Reader.Completion.IsCompletedSuccessfully); + } + + [Fact] + public async Task Processor_RpcLogAfterResponse_UsesThreadPool() + { + // If the Channel was never initialized and we've received an after-completion RpcLog, + // do not initialize the channel. Should fall back to ThreadPool and log. 
+ _dispatcher.DispatchInvocationResponse(CreateInvocationResponse()); + _dispatcher.DispatchRpcLog(CreateRpcLog()); + + await TestHelpers.Await(() => _threadCount == 2); + + Assert.Equal(0, _channelCount); + Assert.Equal(2, _threadCount); + Assert.Collection(_logger.GetLogMessages(), m => Assert.StartsWith("Cannot write", m.FormattedMessage)); + Assert.Null(_dispatcher.MessageChannel); + } + + private static InboundGrpcEvent CreateRpcLog() + { + var msg = new StreamingMessage + { + RpcLog = new RpcLog + { + Message = "test" + } + }; + + return new InboundGrpcEvent("worker_id", msg); + } + + private static InboundGrpcEvent CreateInvocationResponse() + { + var msg = new StreamingMessage + { + InvocationResponse = new InvocationResponse() + }; + + return new InboundGrpcEvent("worker_id", msg); + } + + private void ProcessWithChannel(InboundGrpcEvent msg) + { + System.Threading.Interlocked.Increment(ref _channelCount); // atomic, consistent with ProcessWithThreadPool + } + + private void ProcessWithThreadPool(InboundGrpcEvent msg) + { + System.Threading.Interlocked.Increment(ref _threadCount); // atomic: several ThreadPool work items may run this concurrently + } + + public void Dispose() + { + _dispatcher.Dispose(); + } + } +} diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs b/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs index c014721174..d3d6c11fae 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs @@ -7,7 +7,6 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.Azure.WebJobs.Script.Description; -using Microsoft.Azure.WebJobs.Script.Diagnostics; using Microsoft.Azure.WebJobs.Script.Eventing; using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; using Microsoft.Azure.WebJobs.Script.Grpc.Messages; @@ -270,6 +269,24 @@ public void PublishSystemLogEvent(RpcLog.Types.Level inputLevel) Write(logMessage); } + public void PublishLogEvent(string message, string invocationId) + { + RpcLog rpcLog = new RpcLog() + { + LogCategory = RpcLog.Types.RpcLogCategory.User, + Level = RpcLog.Types.Level.Information, 
+ InvocationId = invocationId, + Message = message + }; + + StreamingMessage logMessage = new StreamingMessage() + { + RpcLog = rpcLog + }; + + Write(logMessage); + } + public static FunctionEnvironmentReloadResponse GetTestFunctionEnvReloadResponse() { StatusResult statusResult = new StatusResult()