diff --git a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs index 8181e6dd39..46769bd2f5 100644 --- a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs +++ b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs @@ -11,6 +11,7 @@ using System.Reactive.Linq; using System.Text; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using System.Threading.Tasks.Dataflow; using Google.Protobuf.Collections; @@ -28,7 +29,7 @@ using Microsoft.Azure.WebJobs.Script.Workers.SharedMemoryDataTransfer; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using Newtonsoft.Json.Linq; + using static Microsoft.Azure.WebJobs.Script.Grpc.Messages.RpcLog.Types; using FunctionMetadata = Microsoft.Azure.WebJobs.Script.Description.FunctionMetadata; using MsgType = Microsoft.Azure.WebJobs.Script.Grpc.Messages.StreamingMessage.ContentOneofCase; @@ -46,6 +47,11 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable private readonly ISharedMemoryManager _sharedMemoryManager; private readonly List _workerStatusLatencyHistory = new List(); private readonly IOptions _workerConcurrencyOptions; + private readonly WaitCallback _processInbound; + private readonly object _syncLock = new object(); + private readonly Dictionary> _pendingActions = new (); + private readonly ChannelWriter _outbound; + private readonly ChannelReader _inbound; private IDisposable _functionLoadRequestResponseEvent; private bool _disposed; @@ -58,10 +64,8 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable private ConcurrentDictionary _executingInvocations = new ConcurrentDictionary(); private IDictionary> _functionInputBuffers = new ConcurrentDictionary>(); private ConcurrentDictionary> _workerStatusRequests = new ConcurrentDictionary>(); - private IObservable _inboundWorkerEvents; private List _inputLinks = new List(); private List _eventSubscriptions = new List(); - private IDisposable _startSubscription; private IDisposable _startLatencyMetric; private IEnumerable _functions; private GrpcCapabilities _workerCapabilities; @@ -75,7 +79,6 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable private bool _isSharedMemoryDataTransferEnabled; private bool _cancelCapabilityEnabled; - private object _syncLock = new object(); private System.Timers.Timer _timer; internal GrpcWorkerChannel( @@ -103,31 +106,24 @@ internal GrpcWorkerChannel( _applicationHostOptions = applicationHostOptions; _sharedMemoryManager = sharedMemoryManager; _workerConcurrencyOptions = workerConcurrencyOptions; + _processInbound = state => ProcessItem((InboundGrpcEvent)state); _workerCapabilities = new GrpcCapabilities(_workerChannelLogger); - _inboundWorkerEvents = _eventManager.OfType() - .Where(msg => msg.WorkerId == _workerId); - - _eventSubscriptions.Add(_inboundWorkerEvents - .Where(msg => msg.IsMessageOfType(MsgType.RpcLog) && !msg.IsLogOfCategory(RpcLogCategory.System)) - .Subscribe(Log)); + if (!_eventManager.TryGetGrpcChannels(workerId, out var inbound, out var outbound)) + { + throw new InvalidOperationException("Could not get gRPC channels for worker ID: " + workerId); + } - _eventSubscriptions.Add(_inboundWorkerEvents - .Where(msg => msg.IsMessageOfType(MsgType.RpcLog) && msg.IsLogOfCategory(RpcLogCategory.System)) - .Subscribe(SystemLog)); + _outbound = outbound.Writer; + _inbound = inbound.Reader; + // note: we don't start the read loop until StartWorkerProcessAsync is called _eventSubscriptions.Add(_eventManager.OfType() .Where(msg => _workerConfig.Description.Extensions.Contains(Path.GetExtension(msg.FileChangeArguments.FullPath))) .Throttle(TimeSpan.FromMilliseconds(300)) // debounce .Subscribe(msg => _eventManager.Publish(new HostRestartEvent()))); - _eventSubscriptions.Add(_inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.InvocationResponse) - .Subscribe(async (msg) => await InvokeResponse(msg.Message.InvocationResponse))); - - _inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.WorkerStatusResponse) - .Subscribe((msg) => ReceiveWorkerStatusResponse(msg.Message.RequestId, msg.Message.WorkerStatusResponse)); - _startLatencyMetric = metricsLogger?.LatencyEvent(string.Format(MetricEventNames.WorkerInitializeLatency, workerConfig.Description.Language, attemptCount)); _state = RpcWorkerChannelState.Default; @@ -141,6 +137,123 @@ internal GrpcWorkerChannel( internal RpcWorkerConfig Config => _workerConfig; + private void ProcessItem(InboundGrpcEvent msg) + { + // note this method is a thread-pool (QueueUserWorkItem) entry-point + try + { + switch (msg.MessageType) + { + case MsgType.RpcLog when msg.Message.RpcLog.LogCategory == RpcLogCategory.System: + SystemLog(msg); + break; + case MsgType.RpcLog: + Log(msg); + break; + case MsgType.WorkerStatusResponse: + ReceiveWorkerStatusResponse(msg.Message.RequestId, msg.Message.WorkerStatusResponse); + break; + case MsgType.InvocationResponse: + _ = InvokeResponse(msg.Message.InvocationResponse); + break; + default: + ProcessRegisteredGrpcCallbacks(msg); + break; + } + } + catch (Exception ex) + { + _workerChannelLogger.LogError(ex, "Error processing InboundGrpcEvent: " + ex.Message); + } + } + + private void ProcessRegisteredGrpcCallbacks(InboundGrpcEvent message) + { + Queue queue; + lock (_pendingActions) + { + if (!_pendingActions.TryGetValue(message.MessageType, out queue)) + { + return; // nothing to do + } + } + PendingItem next; + lock (queue) + { + do + { + if (!queue.TryDequeue(out next)) + { + return; // nothing to do + } + } + while (next.IsComplete); + } + next.SetResult(message); + } + + private void RegisterCallbackForNextGrpcMessage(MsgType messageType, TimeSpan timeout, int count, Action callback, Action faultHandler) + { + Queue queue; + lock (_pendingActions) + { + if (!_pendingActions.TryGetValue(messageType, out queue)) + { + queue = new Queue(); + _pendingActions.Add(messageType, queue); + } + } + + lock (queue) + { + // while we have the lock, discard any dead items (to prevent unbounded growth on stall) + while (queue.TryPeek(out var next) && next.IsComplete) + { + queue.Dequeue(); + } + for (int i = 0; i < count; i++) + { + var newItem = (i == count - 1) && (timeout != TimeSpan.Zero) + ? new PendingItem(callback, faultHandler, timeout) + : new PendingItem(callback, faultHandler); + queue.Enqueue(newItem); + } + } + } + + private async Task ProcessInbound() + { + try + { + await Task.Yield(); // free up the caller + bool debug = _workerChannelLogger.IsEnabled(LogLevel.Debug); + if (debug) + { + _workerChannelLogger.LogDebug("[channel] processing reader loop for worker {0}:", _workerId); + } + while (await _inbound.WaitToReadAsync()) + { + while (_inbound.TryRead(out var msg)) + { + if (debug) + { + _workerChannelLogger.LogDebug("[channel] received {0}: {1}", msg.WorkerId, msg.MessageType); + } + ThreadPool.QueueUserWorkItem(_processInbound, msg); + } + } + } + catch (Exception ex) + { + _workerChannelLogger.LogError(ex, "Error processing inbound messages"); + } + finally + { + // we're not listening any more! shut down the channels + _eventManager.RemoveGrpcChannels(_workerId); + } + } + public bool IsChannelReadyForInvocations() { return !_disposing && !_disposed && _state.HasFlag(RpcWorkerChannelState.InvocationBuffersInitialized | RpcWorkerChannelState.Initialized); @@ -148,10 +261,9 @@ public bool IsChannelReadyForInvocations() public async Task StartWorkerProcessAsync(CancellationToken cancellationToken) { - _startSubscription = _inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.StartStream) - .Timeout(_workerConfig.CountOptions.ProcessStartupTimeout) - .Take(1) - .Subscribe(SendWorkerInitRequest, HandleWorkerStartStreamError); + RegisterCallbackForNextGrpcMessage(MsgType.StartStream, _workerConfig.CountOptions.ProcessStartupTimeout, 1, SendWorkerInitRequest, HandleWorkerStartStreamError); + // note: it is important that the ^^^ StartStream is in place *before* we start process the loop, otherwise we get a race condition + _ = ProcessInbound(); _workerChannelLogger.LogDebug("Initiating Worker Process start up"); await _rpcWorkerProcess.StartProcessAsync(); @@ -178,7 +290,7 @@ public async Task GetWorkerStatusAsync() var tcs = new TaskCompletionSource(); if (_workerStatusRequests.TryAdd(message.RequestId, tcs)) { - SendStreamingMessage(message); + await SendStreamingMessageAsync(message); await tcs.Task; var elapsed = sw.GetElapsedTime(); workerStatus.Latency = elapsed; @@ -199,10 +311,7 @@ public async Task GetWorkerStatusAsync() internal void SendWorkerInitRequest(GrpcEvent startEvent) { _workerChannelLogger.LogDebug("Worker Process started. Received StartStream message"); - _inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.WorkerInitResponse) - .Timeout(_workerConfig.CountOptions.InitializationTimeout) - .Take(1) - .Subscribe(WorkerInitResponse, HandleWorkerInitError); + RegisterCallbackForNextGrpcMessage(MsgType.WorkerInitResponse, _workerConfig.CountOptions.InitializationTimeout, 1, WorkerInitResponse, HandleWorkerInitError); WorkerInitRequest initRequest = GetWorkerInitRequest(); @@ -221,6 +330,11 @@ internal void SendWorkerInitRequest(GrpcEvent startEvent) initRequest.Capabilities.Add(RpcWorkerConstants.FunctionDataCache, "true"); } + // advertise that we support multiple streams, and hint at a number; with this flag, we allow + // clients to connect multiple back-hauls *with the same workerid*, and rely on the internal + // plumbing to make sure we don't process everything N times + initRequest.Capabilities.Add(RpcWorkerConstants.MultiStream, "10"); // TODO: make this configurable + SendStreamingMessage(new StreamingMessage { WorkerInitRequest = initRequest @@ -269,7 +383,7 @@ internal void WorkerInitResponse(GrpcEvent initEvent) if (_initMessage.Result.IsFailure(out Exception exc)) { HandleWorkerInitError(exc); - _workerInitTask.SetResult(false); + _workerInitTask.TrySetResult(false); return; } @@ -284,7 +398,7 @@ internal void WorkerInitResponse(GrpcEvent initEvent) ScriptHost.IsFunctionDataCacheEnabled = false; } - _workerInitTask.SetResult(true); + _workerInitTask.TrySetResult(true); } public void SetupFunctionInvocationBuffers(IEnumerable functions) @@ -308,31 +422,23 @@ public void SendFunctionLoadRequests(ManagedDependencyOptions managedDependencyO // Check if the worker supports this feature bool capabilityEnabled = !string.IsNullOrEmpty(_workerCapabilities.GetCapabilityState(RpcWorkerConstants.SupportsLoadResponseCollection)); - if (capabilityEnabled) + TimeSpan timeout = TimeSpan.Zero; + if (functionTimeout.HasValue) { - var loadResponseCollectionObservable = _inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.FunctionLoadResponseCollection); - if (functionTimeout.HasValue) - { - _functionLoadTimeout = functionTimeout.Value > _functionLoadTimeout ? functionTimeout.Value : _functionLoadTimeout; - loadResponseCollectionObservable = loadResponseCollectionObservable.Timeout(_functionLoadTimeout); - } + _functionLoadTimeout = functionTimeout.Value > _functionLoadTimeout ? functionTimeout.Value : _functionLoadTimeout; + timeout = _functionLoadTimeout; + } - _eventSubscriptions.Add(loadResponseCollectionObservable - .Subscribe((msg) => LoadResponse(msg.Message.FunctionLoadResponseCollection), HandleWorkerFunctionLoadError)); + var count = _functions.Count(); + if (capabilityEnabled) + { + RegisterCallbackForNextGrpcMessage(MsgType.FunctionLoadResponseCollection, timeout, count, msg => LoadResponse(msg.Message.FunctionLoadResponseCollection), HandleWorkerFunctionLoadError); SendFunctionLoadRequestCollection(_functions, managedDependencyOptions); } else { - var loadResponseObservable = _inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.FunctionLoadResponse); - if (functionTimeout.HasValue) - { - _functionLoadTimeout = functionTimeout.Value > _functionLoadTimeout ? functionTimeout.Value : _functionLoadTimeout; - loadResponseObservable = loadResponseObservable.Timeout(_functionLoadTimeout); - } - - _eventSubscriptions.Add(loadResponseObservable.Take(_functions.Count()) - .Subscribe((msg) => LoadResponse(msg.Message.FunctionLoadResponse), HandleWorkerFunctionLoadError)); + RegisterCallbackForNextGrpcMessage(MsgType.FunctionLoadResponse, timeout, count, msg => LoadResponse(msg.Message.FunctionLoadResponse), HandleWorkerFunctionLoadError); foreach (FunctionMetadata metadata in _functions) { @@ -375,11 +481,8 @@ public Task SendFunctionEnvironmentReloadRequest() _workerChannelLogger.LogDebug("Sending FunctionEnvironmentReloadRequest to WorkerProcess with Pid: '{0}'", _rpcWorkerProcess.Id); IDisposable latencyEvent = _metricsLogger.LatencyEvent(MetricEventNames.SpecializationEnvironmentReloadRequestResponse); - _eventSubscriptions - .Add(_inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.FunctionEnvironmentReloadResponse) - .Timeout(_workerConfig.CountOptions.EnvironmentReloadTimeout) - .Take(1) - .Subscribe((msg) => FunctionEnvironmentReloadResponse(msg.Message.FunctionEnvironmentReloadResponse, latencyEvent), HandleWorkerEnvReloadError)); + RegisterCallbackForNextGrpcMessage(MsgType.FunctionEnvironmentReloadResponse, _workerConfig.CountOptions.EnvironmentReloadTimeout, 1, + msg => FunctionEnvironmentReloadResponse(msg.Message.FunctionEnvironmentReloadResponse, latencyEvent), HandleWorkerEnvReloadError); IDictionary processEnv = Environment.GetEnvironmentVariables(); @@ -529,7 +632,7 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context) context.CancellationToken.Register(() => SendInvocationCancel(invocationRequest.InvocationId)); } - SendStreamingMessage(new StreamingMessage + await SendStreamingMessageAsync(new StreamingMessage { InvocationRequest = invocationRequest }); @@ -563,10 +666,8 @@ public Task> GetFunctionMetadata() internal Task> SendFunctionMetadataRequest() { - _eventSubscriptions.Add(_inboundWorkerEvents.Where(msg => msg.MessageType == MsgType.FunctionMetadataResponse) - .Timeout(_functionLoadTimeout) - .Take(1) - .Subscribe((msg) => ProcessFunctionMetadataResponses(msg.Message.FunctionMetadataResponse), HandleWorkerMetadataRequestError)); + RegisterCallbackForNextGrpcMessage(MsgType.FunctionMetadataResponse, _functionLoadTimeout, 1, + msg => ProcessFunctionMetadataResponses(msg.Message.FunctionMetadataResponse), HandleWorkerMetadataRequestError); _workerChannelLogger.LogDebug("Sending WorkerMetadataRequest to {language} worker with worker ID {workerID}", _runtime, _workerId); @@ -865,7 +966,7 @@ internal void HandleWorkerFunctionLoadError(Exception exc) private void PublishWorkerErrorEvent(Exception exc) { - _workerInitTask.SetException(exc); + _workerInitTask.TrySetException(exc); if (_disposing || _disposed) { return; @@ -883,9 +984,45 @@ internal void HandleWorkerMetadataRequestError(Exception exc) _eventManager.Publish(new WorkerErrorEvent(_runtime, Id, exc)); } + private ValueTask SendStreamingMessageAsync(StreamingMessage msg) + { + var evt = new OutboundGrpcEvent(_workerId, msg); + return _outbound.TryWrite(evt) ? default : _outbound.WriteAsync(evt); + } + private void SendStreamingMessage(StreamingMessage msg) { - _eventManager.Publish(new OutboundGrpcEvent(_workerId, msg)); + var evt = new OutboundGrpcEvent(_workerId, msg); + if (!_outbound.TryWrite(evt)) + { + var pending = _outbound.WriteAsync(evt); + if (pending.IsCompleted) + { + try + { + pending.GetAwaiter().GetResult(); // ensure observed to ensure the IValueTaskSource completed/result is consumed + } + catch + { + // suppress failure + } + } + else + { + _ = ObserveEventually(pending); + } + } + static async Task ObserveEventually(ValueTask valueTask) + { + try + { + await valueTask; + } + catch + { + // no where to log + } + } } internal void ReceiveWorkerStatusResponse(string requestId, WorkerStatusResponse response) @@ -903,7 +1040,6 @@ protected virtual void Dispose(bool disposing) if (disposing) { _startLatencyMetric?.Dispose(); - _startSubscription?.Dispose(); _workerInitTask?.TrySetCanceled(); _timer?.Dispose(); @@ -919,6 +1055,9 @@ protected virtual void Dispose(bool disposing) { sub.Dispose(); } + + // shut down the channels + _eventManager.RemoveGrpcChannels(_workerId); } _disposed = true; } @@ -1070,7 +1209,14 @@ internal async void OnTimer(object sender, System.Timers.ElapsedEventArgs e) // Don't allow background execptions to escape // E.g. when a rpc channel is shutting down we can process exceptions } - _timer.Start(); + try + { + _timer.Start(); + } + catch (ObjectDisposedException) + { + // Specifically ignore this race - we're exiting and that's okay + } } private void AddSample(List samples, T sample) @@ -1106,5 +1252,75 @@ private void AddAdditionalTraceContext(MapField attributes, Scri } } } + + private sealed class PendingItem + { + private readonly Action _callback; + private readonly Action _faultHandler; + private CancellationTokenRegistration _ctr; + private int _state; + + public PendingItem(Action callback, Action faultHandler) + { + _callback = callback; + _faultHandler = faultHandler; + } + + public PendingItem(Action callback, Action faultHandler, TimeSpan timeout) + : this(callback, faultHandler) + { + var cts = new CancellationTokenSource(); + cts.CancelAfter(timeout); + _ctr = cts.Token.Register(static state => ((PendingItem)state).OnTimeout(), this); + } + + public bool IsComplete => Volatile.Read(ref _state) != 0; + + private bool MakeComplete() => Interlocked.CompareExchange(ref _state, 1, 0) == 0; + + public void SetResult(InboundGrpcEvent message) + { + _ctr.Dispose(); + _ctr = default; + if (MakeComplete() && _callback != null) + { + try + { + _callback.Invoke(message); + } + catch (Exception fault) + { + try + { + _faultHandler?.Invoke(fault); + } + catch + { + } + } + } + } + + private void OnTimeout() + { + if (MakeComplete() && _faultHandler != null) + { + try + { + throw new TimeoutException(); + } + catch (Exception timeout) + { + try + { + _faultHandler(timeout); + } + catch + { + } + } + } + } + } } } \ No newline at end of file diff --git a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannelFactory.cs b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannelFactory.cs index ea17b97913..09ed942576 100644 --- a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannelFactory.cs +++ b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannelFactory.cs @@ -7,6 +7,7 @@ using System.Reactive.Linq; using Microsoft.Azure.WebJobs.Script.Diagnostics; using Microsoft.Azure.WebJobs.Script.Eventing; +using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; using Microsoft.Azure.WebJobs.Script.Workers; using Microsoft.Azure.WebJobs.Script.Workers.Rpc; using Microsoft.Azure.WebJobs.Script.Workers.SharedMemoryDataTransfer; @@ -47,6 +48,7 @@ public IRpcWorkerChannel Create(string scriptRootPath, string runtime, IMetricsL throw new InvalidOperationException($"WorkerCofig for runtime: {runtime} not found"); } string workerId = Guid.NewGuid().ToString(); + _eventManager.AddGrpcChannels(workerId); // prepare the inbound/outbound dedicated channels ILogger workerLogger = _loggerFactory.CreateLogger($"Worker.LanguageWorkerChannel.{runtime}.{workerId}"); IWorkerProcess rpcWorkerProcess = _rpcWorkerProcessFactory.Create(workerId, runtime, scriptRootPath, languageWorkerConfig); return new GrpcWorkerChannel( diff --git a/src/WebJobs.Script.Grpc/Eventing/GrpcEventExtensions.cs b/src/WebJobs.Script.Grpc/Eventing/GrpcEventExtensions.cs new file mode 100644 index 0000000000..3a7a587bad --- /dev/null +++ b/src/WebJobs.Script.Grpc/Eventing/GrpcEventExtensions.cs @@ -0,0 +1,67 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Threading.Channels; +using Microsoft.Azure.WebJobs.Script.Eventing; + +namespace Microsoft.Azure.WebJobs.Script.Grpc.Eventing; + +internal static class GrpcEventExtensions +{ + // flow here is: + // 1) external request is proxied to the the GrpcWorkerChannel via one of the many Send* APIs, which writes + // to outbound-writer; this means we can have concurrent writes to outbound + // 2) if an out-of-process function is connected, a FunctionRpcService-EventStream will consume + // from outbound-reader (we'll allow for the multi-stream possibility, hence concurrent), and push it via gRPC + // 3) when the out-of-process function provides a response to FunctionRpcService-EventStream, it is written to + // inbound-writer (note we will allow for multi-stream possibility) + // 4) the GrpcWorkerChannel has a single dedicated consumer of inbound-reader, which it then marries to + // in-flight operations + internal static readonly UnboundedChannelOptions InboundOptions = new UnboundedChannelOptions + { + SingleReader = true, // see 4 + SingleWriter = false, // see 3 + AllowSynchronousContinuations = false, + }; + + internal static readonly UnboundedChannelOptions OutboundOptions = new UnboundedChannelOptions + { + SingleReader = false, // see 2 + SingleWriter = false, // see 1 + AllowSynchronousContinuations = false, + }; + + public static void AddGrpcChannels(this IScriptEventManager manager, string workerId) + { + var inbound = Channel.CreateUnbounded(InboundOptions); + if (manager.TryAddWorkerState(workerId, inbound)) + { + var outbound = Channel.CreateUnbounded(OutboundOptions); + if (manager.TryAddWorkerState(workerId, outbound)) + { + return; // successfully added both + } + // we added the inbound but not the outbound; revert + manager.TryRemoveWorkerState(workerId, out inbound); + } + // this is not anticipated, so don't panic abount the allocs above + throw new ArgumentException("Duplicate worker id: " + workerId, nameof(workerId)); + } + + public static bool TryGetGrpcChannels(this IScriptEventManager manager, string workerId, out Channel inbound, out Channel outbound) + => manager.TryGetWorkerState(workerId, out inbound) & manager.TryGetWorkerState(workerId, out outbound); + + public static void RemoveGrpcChannels(this IScriptEventManager manager, string workerId) + { + // remove any channels, and shut them down + if (manager.TryGetWorkerState>(workerId, out var inbound)) + { + inbound.Writer.TryComplete(); + } + if (manager.TryGetWorkerState>(workerId, out var outbound)) + { + outbound.Writer.TryComplete(); + } + } +} diff --git a/src/WebJobs.Script.Grpc/Extensions/InboundGrpcEventExtensions.cs b/src/WebJobs.Script.Grpc/Extensions/InboundGrpcEventExtensions.cs deleted file mode 100644 index a1ae865b08..0000000000 --- a/src/WebJobs.Script.Grpc/Extensions/InboundGrpcEventExtensions.cs +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the MIT License. See License.txt in the project root for license information. - -using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; -using static Microsoft.Azure.WebJobs.Script.Grpc.Messages.RpcLog.Types; -using MessageType = Microsoft.Azure.WebJobs.Script.Grpc.Messages.StreamingMessage.ContentOneofCase; - -namespace Microsoft.Azure.WebJobs.Script.Grpc -{ - public static class InboundGrpcEventExtensions - { - public static bool IsMessageOfType(this InboundGrpcEvent inboundEvent, MessageType typeToCheck) - { - return inboundEvent.MessageType.Equals(typeToCheck); - } - - public static bool IsLogOfCategory(this InboundGrpcEvent inboundEvent, RpcLogCategory categoryToCheck) - { - return inboundEvent.Message.RpcLog.LogCategory.Equals(categoryToCheck); - } - } -} diff --git a/src/WebJobs.Script.Grpc/Server/FunctionRpcService.cs b/src/WebJobs.Script.Grpc/Server/FunctionRpcService.cs index fd9bcfde0f..e69c21aad5 100644 --- a/src/WebJobs.Script.Grpc/Server/FunctionRpcService.cs +++ b/src/WebJobs.Script.Grpc/Server/FunctionRpcService.cs @@ -2,10 +2,9 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Reactive.Concurrency; using System.Reactive.Linq; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Grpc.Core; using Microsoft.Azure.WebJobs.Script.Eventing; @@ -21,7 +20,6 @@ namespace Microsoft.Azure.WebJobs.Script.Grpc // TODO: move to WebJobs.Script.Grpc package and provide event stream abstraction internal class FunctionRpcService : FunctionRpc.FunctionRpcBase { - private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1, 1); private readonly IScriptEventManager _eventManager; private readonly ILogger _logger; @@ -33,64 +31,48 @@ public FunctionRpcService(IScriptEventManager eventManager, ILogger requestStream, IServerStreamWriter responseStream, ServerCallContext context) { - var cancelSource = new TaskCompletionSource(); - IDictionary outboundEventSubscriptions = new Dictionary(); - + var cancelSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var cts = CancellationTokenSource.CreateLinkedTokenSource(context.CancellationToken); + CancellationTokenRegistration ctr = cts.Token.Register(static state => ((TaskCompletionSource)state).TrySetResult(false), cancelSource); try { - context.CancellationToken.Register(() => cancelSource.TrySetResult(false)); - Func> messageAvailable = async () => + static Task> MoveNextAsync(IAsyncStreamReader requestStream, TaskCompletionSource cancelSource) { // GRPC does not accept cancellation tokens for individual reads, hence wrapper var requestTask = requestStream.MoveNext(CancellationToken.None); - var completed = await Task.WhenAny(cancelSource.Task, requestTask); - return completed.Result; - }; + return Task.WhenAny(cancelSource.Task, requestTask); + } - if (await messageAvailable()) + if (await await MoveNextAsync(requestStream, cancelSource)) { - string workerId = requestStream.Current.StartStream.WorkerId; - _logger.LogDebug("Established RPC channel. WorkerId: {workerId}", workerId); - outboundEventSubscriptions.Add(workerId, _eventManager.OfType() - .Where(evt => evt.WorkerId == workerId) - .ObserveOn(NewThreadScheduler.Default) - .Subscribe(async evt => + var currentMessage = requestStream.Current; + // expect first operation (and only the first; we don't support re-registration) to be StartStream + if (currentMessage.ContentCase == MsgType.StartStream) + { + var workerId = currentMessage.StartStream?.WorkerId; + if (!string.IsNullOrEmpty(workerId) && _eventManager.TryGetGrpcChannels(workerId, out var inbound, out var outbound)) { - try - { - if (evt.MessageType == MsgType.InvocationRequest) - { - _logger.LogTrace("Writing invocation request invocationId: {invocationId} to workerId: {workerId}", evt.Message.InvocationRequest.InvocationId, workerId); - } + // send work + _ = PushFromOutboundToGrpc(workerId, responseStream, outbound.Reader, cts.Token); - try + // this loop is "pull from gRPC and push to inbound" + do + { + currentMessage = requestStream.Current; + if (currentMessage.ContentCase == MsgType.InvocationResponse && !string.IsNullOrEmpty(currentMessage.InvocationResponse?.InvocationId)) { - // WriteAsync only allows one pending write at a time, so we - // serialize access to the stream for each subscription - await _writeLock.WaitAsync(); - await responseStream.WriteAsync(evt.Message); + _logger.LogTrace("Received invocation response for invocationId: {invocationId} from workerId: {workerId}", currentMessage.InvocationResponse.InvocationId, workerId); } - finally + var newInbound = new InboundGrpcEvent(workerId, currentMessage); + if (!inbound.Writer.TryWrite(newInbound)) { - _writeLock.Release(); + await inbound.Writer.WriteAsync(newInbound); } + currentMessage = null; // allow old messages to be collected while we wait } - catch (Exception subscribeEventEx) - { - _logger.LogError(subscribeEventEx, "Error writing message type {messageType} to workerId: {workerId}", evt.MessageType, workerId); - } - })); - - do - { - var currentMessage = requestStream.Current; - if (currentMessage.InvocationResponse != null && !string.IsNullOrEmpty(currentMessage.InvocationResponse.InvocationId)) - { - _logger.LogTrace("Received invocation response for invocationId: {invocationId} from workerId: {workerId}", currentMessage.InvocationResponse.InvocationId, workerId); + while (await await MoveNextAsync(requestStream, cancelSource)); } - _eventManager.Publish(new InboundGrpcEvent(workerId, currentMessage)); } - while (await messageAvailable()); } } catch (Exception rpcException) @@ -101,14 +83,47 @@ public override async Task EventStream(IAsyncStreamReader requ } finally { - foreach (var sub in outboundEventSubscriptions) - { - sub.Value?.Dispose(); - } + cts.Cancel(); + ctr.Dispose(); // ensure cancellationSource task completes cancelSource.TrySetResult(false); } } + + private async Task PushFromOutboundToGrpc(string workerId, IServerStreamWriter responseStream, ChannelReader source, CancellationToken cancellationToken) + { + try + { + _logger.LogDebug("Established RPC channel. WorkerId: {workerId}", workerId); + await Task.Yield(); // free up the caller + while (await source.WaitToReadAsync(cancellationToken)) + { + while (source.TryRead(out var evt)) + { + if (evt.MessageType == MsgType.InvocationRequest) + { + _logger.LogTrace("Writing invocation request invocationId: {invocationId} to workerId: {workerId}", evt.Message.InvocationRequest.InvocationId, workerId); + } + try + { + await responseStream.WriteAsync(evt.Message); + } + catch (Exception subscribeEventEx) + { + _logger.LogError(subscribeEventEx, "Error writing message type {messageType} to workerId: {workerId}", evt.MessageType, workerId); + } + } + } + } + catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken) + { + // that's fine; normaly exit through cancellation + } + catch (Exception ex) + { + _logger.LogError(ex, "Error pushing from outbound to gRPC"); + } + } } } diff --git a/src/WebJobs.Script.Grpc/WebJobs.Script.Grpc.csproj b/src/WebJobs.Script.Grpc/WebJobs.Script.Grpc.csproj index fdab414a04..1c26ded690 100644 --- a/src/WebJobs.Script.Grpc/WebJobs.Script.Grpc.csproj +++ b/src/WebJobs.Script.Grpc/WebJobs.Script.Grpc.csproj @@ -30,6 +30,7 @@ all + diff --git a/src/WebJobs.Script/Eventing/IScriptEventManager.cs b/src/WebJobs.Script/Eventing/IScriptEventManager.cs index 889d508e96..8c3c8badb0 100644 --- a/src/WebJobs.Script/Eventing/IScriptEventManager.cs +++ b/src/WebJobs.Script/Eventing/IScriptEventManager.cs @@ -8,5 +8,11 @@ namespace Microsoft.Azure.WebJobs.Script.Eventing public interface IScriptEventManager : IObservable { void Publish(ScriptEvent scriptEvent); + + bool TryAddWorkerState(string workerId, T state); + + bool TryGetWorkerState(string workerId, out T state); + + bool TryRemoveWorkerState(string workerId, out T state); } } diff --git a/src/WebJobs.Script/Eventing/ScriptEventManager.cs b/src/WebJobs.Script/Eventing/ScriptEventManager.cs index 91c29ff40f..f3d1e360ce 100644 --- a/src/WebJobs.Script/Eventing/ScriptEventManager.cs +++ b/src/WebJobs.Script/Eventing/ScriptEventManager.cs @@ -2,13 +2,16 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; +using System.Collections.Concurrent; using System.Reactive.Subjects; namespace Microsoft.Azure.WebJobs.Script.Eventing { - public sealed class ScriptEventManager : IScriptEventManager, IDisposable + public class ScriptEventManager : IScriptEventManager, IDisposable { private readonly Subject _subject = new Subject(); + private readonly ConcurrentDictionary<(string, Type), object> _workerState = new (); + private bool _disposed = false; public void Publish(ScriptEvent scriptEvent) @@ -47,5 +50,35 @@ private void Dispose(bool disposing) } public void Dispose() => Dispose(true); + + bool IScriptEventManager.TryAddWorkerState(string workerId, T state) + { + var key = (workerId, typeof(T)); + return _workerState.TryAdd(key, state); + } + + bool IScriptEventManager.TryGetWorkerState(string workerId, out T state) + { + var key = (workerId, typeof(T)); + if (_workerState.TryGetValue(key, out var tmp) && tmp is T typed) + { + state = typed; + return true; + } + state = default; + return false; + } + + bool IScriptEventManager.TryRemoveWorkerState(string workerId, out T state) + { + var key = (workerId, typeof(T)); + if (_workerState.TryRemove(key, out var tmp) && tmp is T typed) + { + state = typed; + return true; + } + state = default; + return false; + } } } diff --git a/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs b/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs index 7538284620..aee89c1929 100644 --- a/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs +++ b/src/WebJobs.Script/Workers/Rpc/RpcWorkerConstants.cs @@ -62,6 +62,7 @@ public static class RpcWorkerConstants // Host Capabilities public const string V2Compatable = "V2Compatable"; + public const string MultiStream = nameof(MultiStream); // dotnet executable file path components public const string DotNetExecutableName = "dotnet"; diff --git a/test/WebJobs.Script.Tests.Integration/Rpc/WorkerConcurrencyManagerEndToEndTests.cs b/test/WebJobs.Script.Tests.Integration/Rpc/WorkerConcurrencyManagerEndToEndTests.cs index d1053d24fe..9647ca4820 100644 --- a/test/WebJobs.Script.Tests.Integration/Rpc/WorkerConcurrencyManagerEndToEndTests.cs +++ b/test/WebJobs.Script.Tests.Integration/Rpc/WorkerConcurrencyManagerEndToEndTests.cs @@ -1,8 +1,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.Azure.WebJobs.Script.Eventing; +using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; using Microsoft.Azure.WebJobs.Script.Workers.Rpc; using Xunit; @@ -22,8 +25,7 @@ public async Task WorkerStatus_NewWorkerAdded() { RpcFunctionInvocationDispatcher fd = null; IEnumerable channels = null; - // Latency > 1s - TestScriptEventManager.WaitBeforePublish = TimeSpan.FromSeconds(2); + await TestHelpers.Await(async () => { fd = Fixture.JobHost.FunctionDispatcher as RpcFunctionInvocationDispatcher; @@ -34,9 +36,11 @@ await TestHelpers.Await(async () => public class TestFixture : ScriptHostEndToEndTestFixture { + // Latency > 1s public TestFixture() : base(@"TestScripts\Node", "node", RpcWorkerConstants.NodeLanguageWorkerName, startHost: true, functions: new[] { "HttpTrigger" }, - addWorkerConcurrency: true) + addWorkerConcurrency: true, + addWorkerDelay: TimeSpan.FromSeconds(2)) { } } @@ -44,19 +48,16 @@ public TestFixture() : base(@"TestScripts\Node", "node", RpcWorkerConstants.Node internal class TestScriptEventManager : IScriptEventManager, IDisposable { private readonly IScriptEventManager _scriptEventManager; + private readonly TimeSpan _delay; - - public TestScriptEventManager() + public TestScriptEventManager(TimeSpan delay) { _scriptEventManager = new ScriptEventManager(); + _delay = delay; } - public static TimeSpan WaitBeforePublish; - - public async void Publish(ScriptEvent scriptEvent) + public void Publish(ScriptEvent scriptEvent) { - // Emulate long worker status latency - await Task.Delay(WaitBeforePublish); try { _scriptEventManager.Publish(scriptEvent); @@ -67,12 +68,52 @@ public async void Publish(ScriptEvent scriptEvent) } } - public IDisposable Subscribe(IObserver observer) + public IDisposable Subscribe(IObserver observer) => _scriptEventManager.Subscribe(observer); + + public void Dispose() => ((IDisposable)_scriptEventManager).Dispose(); + public bool TryAddWorkerState(string workerId, T state) { - return _scriptEventManager.Subscribe(observer); + // Swap for a channel that imposes a delay into the pipe + if (typeof(T) == typeof(Channel) && _delay > TimeSpan.Zero) + { + state = (T)(object)(new DelayedOutboundChannel(_delay)); + } + return _scriptEventManager.TryAddWorkerState(workerId, state); } - public void Dispose() => ((IDisposable)_scriptEventManager).Dispose(); + public bool TryGetWorkerState(string workerId, out T state) + => _scriptEventManager.TryGetWorkerState(workerId, out state); + + public bool TryRemoveWorkerState(string workerId, out T state) + => _scriptEventManager.TryRemoveWorkerState(workerId, out state); + + + public class DelayedOutboundChannel : Channel + { + public DelayedOutboundChannel(TimeSpan delay) + { + var toWrap = Channel.CreateUnbounded(GrpcEventExtensions.OutboundOptions); + Reader = toWrap.Reader; + Writer = new DelayedChannelWriter(toWrap.Writer, delay); + } + } + + public class DelayedChannelWriter : ChannelWriter + { + private readonly TimeSpan _delay; + private readonly ChannelWriter _inner; + + public DelayedChannelWriter(ChannelWriter toWrap, TimeSpan delay) => (_inner, _delay) = (toWrap, delay); + + public override bool TryWrite(T item) => false; // Always fail, so we bounce to WriteAsync + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken = default) => _inner.WaitToWriteAsync(cancellationToken); + + public override async ValueTask WriteAsync(T item, CancellationToken cancellationToken = default) + { + await Task.Delay(_delay, cancellationToken); + await _inner.WriteAsync(item, cancellationToken); + } + } } } } diff --git a/test/WebJobs.Script.Tests.Integration/ScriptHostEndToEnd/ScriptHostEndToEndTestFixture.cs b/test/WebJobs.Script.Tests.Integration/ScriptHostEndToEnd/ScriptHostEndToEndTestFixture.cs index 5fcc5b6204..1b15d8d5dd 100644 --- a/test/WebJobs.Script.Tests.Integration/ScriptHostEndToEnd/ScriptHostEndToEndTestFixture.cs +++ b/test/WebJobs.Script.Tests.Integration/ScriptHostEndToEnd/ScriptHostEndToEndTestFixture.cs @@ -40,9 +40,10 @@ public abstract class ScriptHostEndToEndTestFixture : IAsyncLifetime private readonly ICollection _functions; private readonly string _functionsWorkerLanguage; private readonly bool _addWorkerConcurrency; + private readonly TimeSpan? _addWorkerDelay; protected ScriptHostEndToEndTestFixture(string rootPath, string testId, string functionsWorkerLanguage, - bool startHost = true, ICollection functions = null, bool addWorkerConcurrency = false) + bool startHost = true, ICollection functions = null, bool addWorkerConcurrency = false, TimeSpan? addWorkerDelay = null) { _settingsManager = ScriptSettingsManager.Instance; FixtureId = testId; @@ -56,6 +57,7 @@ protected ScriptHostEndToEndTestFixture(string rootPath, string testId, string f _functions = functions; _functionsWorkerLanguage = functionsWorkerLanguage; _addWorkerConcurrency = addWorkerConcurrency; + _addWorkerDelay = addWorkerDelay; } public TestLoggerProvider LoggerProvider { get; } @@ -155,9 +157,9 @@ public async Task InitializeAsync() services.AddSingleton(); } services.AddSingleton(); - if (_addWorkerConcurrency) + if (_addWorkerConcurrency && _addWorkerDelay > TimeSpan.Zero) { - services.AddSingleton(); + services.AddSingleton(new WorkerConcurrencyManagerEndToEndTests.TestScriptEventManager(_addWorkerDelay.Value)); } ConfigureServices(services); diff --git a/test/WebJobs.Script.Tests.Shared/TestLogger.cs b/test/WebJobs.Script.Tests.Shared/TestLogger.cs index 45ac491d4b..c71f3bcee8 100644 --- a/test/WebJobs.Script.Tests.Shared/TestLogger.cs +++ b/test/WebJobs.Script.Tests.Shared/TestLogger.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Linq; using Microsoft.Extensions.Logging; +using Xunit.Abstractions; namespace Microsoft.Azure.WebJobs.Script.Tests { @@ -20,11 +21,13 @@ public class TestLogger : ILogger { private readonly object _syncLock = new object(); private readonly IExternalScopeProvider _scopeProvider; - private IList _logMessages = new List(); + private readonly IList _logMessages = new List(); + private readonly ITestOutputHelper _testOutput; // optionally write direct to the test output - public TestLogger(string category) + public TestLogger(string category, ITestOutputHelper testOutput = null) : this(category, new LoggerExternalScopeProvider()) { + _testOutput = testOutput; } public TestLogger(string category, IExternalScopeProvider scopeProvider) @@ -78,6 +81,7 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except { _logMessages.Add(logMessage); } + _testOutput?.WriteLine($"{logLevel}: {formatter(state, exception)}"); } } diff --git a/test/WebJobs.Script.Tests.Shared/TestScriptEventManager.cs b/test/WebJobs.Script.Tests.Shared/TestScriptEventManager.cs index 62bddceae6..d92da6cc3c 100644 --- a/test/WebJobs.Script.Tests.Shared/TestScriptEventManager.cs +++ b/test/WebJobs.Script.Tests.Shared/TestScriptEventManager.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; +using System.Threading.Channels; using Microsoft.Azure.WebJobs.Script.Eventing; namespace Microsoft.Azure.WebJobs.Script.Tests @@ -16,5 +17,26 @@ public IDisposable Subscribe(IObserver observer) { return null; } + + public bool TryGetDedicatedChannelFor(string workerId, out Channel channel) where T : ScriptEvent + { + channel = default; + return false; + } + + bool IScriptEventManager.TryAddWorkerState(string workerId, T state) + => false; + + bool IScriptEventManager.TryGetWorkerState(string workerId, out T state) + { + state = default; + return false; + } + + bool IScriptEventManager.TryRemoveWorkerState(string workerId, out T state) + { + state = default; + return false; + } } } diff --git a/test/WebJobs.Script.Tests/Configuration/DefaultDependencyValidatorTests.cs b/test/WebJobs.Script.Tests/Configuration/DefaultDependencyValidatorTests.cs index bcb6657413..8c87300597 100644 --- a/test/WebJobs.Script.Tests/Configuration/DefaultDependencyValidatorTests.cs +++ b/test/WebJobs.Script.Tests/Configuration/DefaultDependencyValidatorTests.cs @@ -6,6 +6,7 @@ using System.IO; using System.Linq; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; using Microsoft.Azure.WebJobs.Script.Diagnostics; @@ -151,6 +152,27 @@ public IDisposable Subscribe(IObserver observer) { return null; } + + public bool TryGetDedicatedChannelFor(string workerId, out Channel channel) where T : ScriptEvent + { + channel = null; + return false; + } + + bool IScriptEventManager.TryAddWorkerState(string workerId, T state) + => false; + + bool IScriptEventManager.TryGetWorkerState(string workerId, out T state) + { + state = default; + return false; + } + + bool IScriptEventManager.TryRemoveWorkerState(string workerId, out T state) + { + state = default; + return false; + } } private class MyMetricsLogger : IMetricsLogger diff --git a/test/WebJobs.Script.Tests/Extensions/InboundGrpcEventExtensionsTests.cs b/test/WebJobs.Script.Tests/Extensions/InboundGrpcEventExtensionsTests.cs index 9d63c730d8..c3baf8d6df 100644 --- a/test/WebJobs.Script.Tests/Extensions/InboundGrpcEventExtensionsTests.cs +++ b/test/WebJobs.Script.Tests/Extensions/InboundGrpcEventExtensionsTests.cs @@ -25,7 +25,7 @@ public void TestLogCategories(RpcLogCategory categoryToTest) } }); - Assert.True(inboundEvent.IsLogOfCategory(categoryToTest)); + Assert.True(inboundEvent.Message.RpcLog.LogCategory == categoryToTest); } } } diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs index bd5552a61e..475901c870 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs @@ -25,6 +25,7 @@ using Microsoft.Extensions.Options; using Moq; using Xunit; +using Xunit.Abstractions; namespace Microsoft.Azure.WebJobs.Script.Tests.Workers.Rpc { @@ -38,7 +39,6 @@ public class GrpcWorkerChannelTests : IDisposable private readonly string _workerId = "testWorkerId"; private readonly string _scriptRootPath = "c:\testdir"; private readonly IScriptEventManager _eventManager = new ScriptEventManager(); - private readonly Mock _eventManagerMock = new Mock(); private readonly TestMetricsLogger _metricsLogger = new TestMetricsLogger(); private readonly Mock _mockConsoleLogger = new Mock(); private readonly Mock _mockFunctionRpcService = new Mock(); @@ -54,12 +54,14 @@ public class GrpcWorkerChannelTests : IDisposable private readonly ISharedMemoryManager _sharedMemoryManager; private readonly IFunctionDataCache _functionDataCache; private readonly IOptions _workerConcurrencyOptions; + private readonly ITestOutputHelper _testOutput; private GrpcWorkerChannel _workerChannel; - private GrpcWorkerChannel _workerChannelWithMockEventManager; - public GrpcWorkerChannelTests() + public GrpcWorkerChannelTests(ITestOutputHelper testOutput) { - _logger = new TestLogger("FunctionDispatcherTests"); + _eventManager.AddGrpcChannels(_workerId); + _testOutput = testOutput; + _logger = new TestLogger("FunctionDispatcherTests", testOutput); _testFunctionRpcService = new TestFunctionRpcService(_eventManager, _workerId, _logger, _expectedLogMsg); _testWorkerConfig = TestHelpers.GetTestWorkerConfigs().FirstOrDefault(); _testWorkerConfig.CountOptions.ProcessStartupTimeout = TimeSpan.FromSeconds(5); @@ -95,6 +97,11 @@ public GrpcWorkerChannelTests() }; _hostOptionsMonitor = TestHelpers.CreateOptionsMonitor(hostOptions); + _testEnvironment.SetEnvironmentVariable("APPLICATIONINSIGHTS_ENABLE_AGENT", "true"); + } + + private Task CreateDefaultWorkerChannel(bool autoStart = true, IDictionary capabilities = null) + { _workerChannel = new GrpcWorkerChannel( _workerId, _eventManager, @@ -109,22 +116,31 @@ public GrpcWorkerChannelTests() _functionDataCache, _workerConcurrencyOptions); - _eventManagerMock.Setup(proxy => proxy.Publish(It.IsAny())).Verifiable(); - _workerChannelWithMockEventManager = new GrpcWorkerChannel( - _workerId, - _eventManagerMock.Object, - _testWorkerConfig, - _mockrpcWorkerProcess.Object, - _logger, - _metricsLogger, - 0, - _testEnvironment, - _hostOptionsMonitor, - _sharedMemoryManager, - _functionDataCache, - _workerConcurrencyOptions); + if (autoStart) + { + // for most tests, we want things to be responsive to inbound messages + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.WorkerInitRequest, + _ => _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities)); + return _workerChannel.StartWorkerProcessAsync(CancellationToken.None); + } + else + { + return Task.CompletedTask; + } + } - _testEnvironment.SetEnvironmentVariable("APPLICATIONINSIGHTS_ENABLE_AGENT", "true"); + private void ShowOutput(string message) + => _testOutput?.WriteLine(message); + + private void ShowOutput(IList messages) + { + if (_testOutput is not null && messages is not null) + { + foreach (var msg in messages) + { + _testOutput.WriteLine(msg.FormattedMessage); + } + } } public void Dispose() @@ -135,10 +151,9 @@ public void Dispose() [Fact] public async Task StartWorkerProcessAsync_ThrowsTaskCanceledException_IfDisposed() { - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); + var initTask = CreateDefaultWorkerChannel(); + _workerChannel.Dispose(); - _testFunctionRpcService.PublishStartStreamEvent(_workerId); - _testFunctionRpcService.PublishWorkerInitResponseEvent(); await Assert.ThrowsAsync(async () => { await initTask; @@ -146,14 +161,9 @@ await Assert.ThrowsAsync(async () => } [Fact] - public void WorkerChannel_Dispose_With_WorkerTerminateCapability() + public async Task WorkerChannel_Dispose_With_WorkerTerminateCapability() { - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); - - IDictionary capabilities = new Dictionary() - { - { RpcWorkerConstants.HandlesWorkerTerminateMessage, "1" } - }; + await CreateDefaultWorkerChannel(capabilities: new Dictionary() { { RpcWorkerConstants.HandlesWorkerTerminateMessage, "1" } }); StartStream startStream = new StartStream() { @@ -167,8 +177,10 @@ public void WorkerChannel_Dispose_With_WorkerTerminateCapability() // Send worker init request and enable the capabilities GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); + _testFunctionRpcService.AutoReply(StreamingMessage.ContentOneofCase.WorkerInitRequest); _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); + + await Task.Delay(500); _workerChannel.Dispose(); var traces = _logger.GetLogMessages(); @@ -177,9 +189,9 @@ public void WorkerChannel_Dispose_With_WorkerTerminateCapability() } [Fact] - public void WorkerChannel_Dispose_Without_WorkerTerminateCapability() + public async Task WorkerChannel_Dispose_Without_WorkerTerminateCapability() { - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); + await CreateDefaultWorkerChannel(); _workerChannel.Dispose(); var traces = _logger.GetLogMessages(); @@ -190,10 +202,7 @@ public void WorkerChannel_Dispose_Without_WorkerTerminateCapability() [Fact] public async Task StartWorkerProcessAsync_Invoked_SetupFunctionBuffers_Verify_ReadyForInvocation() { - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); - _testFunctionRpcService.PublishStartStreamEvent(_workerId); - _testFunctionRpcService.PublishWorkerInitResponseEvent(); - await initTask; + await CreateDefaultWorkerChannel(); _mockrpcWorkerProcess.Verify(m => m.StartProcessAsync(), Times.Once); Assert.False(_workerChannel.IsChannelReadyForInvocations()); _workerChannel.SetupFunctionInvocationBuffers(GetTestFunctionsList("node")); @@ -203,20 +212,26 @@ public async Task StartWorkerProcessAsync_Invoked_SetupFunctionBuffers_Verify_Re [Fact] public async Task DisposingChannel_NotReadyForInvocation() { - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); - _testFunctionRpcService.PublishStartStreamEvent(_workerId); - _testFunctionRpcService.PublishWorkerInitResponseEvent(); - await initTask; - Assert.False(_workerChannel.IsChannelReadyForInvocations()); - _workerChannel.SetupFunctionInvocationBuffers(GetTestFunctionsList("node")); - Assert.True(_workerChannel.IsChannelReadyForInvocations()); - _workerChannel.Dispose(); - Assert.False(_workerChannel.IsChannelReadyForInvocations()); + try + { + await CreateDefaultWorkerChannel(); + Assert.False(_workerChannel.IsChannelReadyForInvocations()); + _workerChannel.SetupFunctionInvocationBuffers(GetTestFunctionsList("node")); + Assert.True(_workerChannel.IsChannelReadyForInvocations()); + _workerChannel.Dispose(); + Assert.False(_workerChannel.IsChannelReadyForInvocations()); + } + finally + { + var traces = _logger.GetLogMessages(); + ShowOutput(traces); + } } [Fact] public void SetupFunctionBuffers_Verify_ReadyForInvocation_Returns_False() { + CreateDefaultWorkerChannel(); Assert.False(_workerChannel.IsChannelReadyForInvocations()); _workerChannel.SetupFunctionInvocationBuffers(GetTestFunctionsList("node")); Assert.False(_workerChannel.IsChannelReadyForInvocations()); @@ -225,6 +240,7 @@ public void SetupFunctionBuffers_Verify_ReadyForInvocation_Returns_False() [Fact] public async Task StartWorkerProcessAsync_TimesOut() { + await CreateDefaultWorkerChannel(autoStart: false); // suppress for timeout var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); await Assert.ThrowsAsync(async () => await initTask); } @@ -232,6 +248,7 @@ public async Task StartWorkerProcessAsync_TimesOut() [Fact] public async Task SendEnvironmentReloadRequest_Generates_ExpectedMetrics() { + await CreateDefaultWorkerChannel(); _metricsLogger.ClearCollections(); Task waitForMetricsTask = Task.Factory.StartNew(() => { @@ -247,6 +264,7 @@ public async Task SendEnvironmentReloadRequest_Generates_ExpectedMetrics() [Fact] public async Task StartWorkerProcessAsync_WorkerProcess_Throws() { + // note: uses custom worker channel Mock mockrpcWorkerProcessThatThrows = new Mock(); mockrpcWorkerProcessThatThrows.Setup(m => m.StartProcessAsync()).Throws(); @@ -267,8 +285,9 @@ public async Task StartWorkerProcessAsync_WorkerProcess_Throws() } [Fact] - public void SendWorkerInitRequest_PublishesOutboundEvent() + public async Task SendWorkerInitRequest_PublishesOutboundEvent() { + await CreateDefaultWorkerChannel(autoStart: false); // we'll do it manually here StartStream startStream = new StartStream() { WorkerId = _workerId @@ -278,8 +297,9 @@ public void SendWorkerInitRequest_PublishesOutboundEvent() StartStream = startStream }; GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); + _testFunctionRpcService.AutoReply(StreamingMessage.ContentOneofCase.WorkerInitRequest); _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(); + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg))); } @@ -287,6 +307,7 @@ public void SendWorkerInitRequest_PublishesOutboundEvent() [Fact] public void WorkerInitRequest_Expected() { + CreateDefaultWorkerChannel(autoStart: false); // doesn't actually need to run; just not be null WorkerInitRequest initRequest = _workerChannel.GetWorkerInitRequest(); Assert.NotNull(initRequest.WorkerDirectory); Assert.NotNull(initRequest.FunctionAppDirectory); @@ -297,20 +318,11 @@ public void WorkerInitRequest_Expected() } [Fact] - public void SendWorkerInitRequest_PublishesOutboundEvent_V2Compatable() + public async Task SendWorkerInitRequest_PublishesOutboundEvent_V2Compatable() { _testEnvironment.SetEnvironmentVariable(EnvironmentSettingNames.FunctionsV2CompatibilityModeKey, "true"); - StartStream startStream = new StartStream() - { - WorkerId = _workerId - }; - StreamingMessage startStreamMessage = new StreamingMessage() - { - StartStream = startStream - }; - GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); - _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(); + await CreateDefaultWorkerChannel(); + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg))); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Worker and host running in V2 compatibility mode"))); @@ -321,9 +333,11 @@ public void SendWorkerInitRequest_PublishesOutboundEvent_V2Compatable() [InlineData(RpcLog.Types.Level.Error, RpcLog.Types.Level.Error)] [InlineData(RpcLog.Types.Level.Warning, RpcLog.Types.Level.Warning)] [InlineData(RpcLog.Types.Level.Trace, RpcLog.Types.Level.Information)] - public void SendSystemLogMessage_PublishesSystemLogMessage(RpcLog.Types.Level levelToTest, RpcLog.Types.Level expectedLogLevel) + public async Task SendSystemLogMessage_PublishesSystemLogMessage(RpcLog.Types.Level levelToTest, RpcLog.Types.Level expectedLogLevel) { + await CreateDefaultWorkerChannel(); _testFunctionRpcService.PublishSystemLogEvent(levelToTest); + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedSystemLogMessage) && m.Level.ToString().Equals(expectedLogLevel.ToString()))); } @@ -331,8 +345,10 @@ public void SendSystemLogMessage_PublishesSystemLogMessage(RpcLog.Types.Level le [Fact] public async Task SendInvocationRequest_PublishesOutboundEvent() { + await CreateDefaultWorkerChannel(); ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null); await _workerChannel.SendInvocationRequest(scriptInvocationContext); + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg))); } @@ -340,6 +356,7 @@ public async Task SendInvocationRequest_PublishesOutboundEvent() [Fact] public async Task SendInvocationRequest_IsInExecutingInvocation() { + await CreateDefaultWorkerChannel(); ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null); await _workerChannel.SendInvocationRequest(scriptInvocationContext); Assert.True(_workerChannel.IsExecutingInvocation(scriptInvocationContext.ExecutionContext.InvocationId.ToString())); @@ -352,11 +369,12 @@ public async Task SendInvocationRequest_IsInExecutingInvocation() [Fact] public async Task SendInvocationRequest_InputsTransferredOverSharedMemory() { - EnableSharedMemoryDataTransfer(); + await CreateSharedMemoryEnabledWorkerChannel(); // Send invocation which will be using RpcSharedMemory for the inputs ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContextWithSharedMemoryInputs(Guid.NewGuid(), null); await _workerChannel.SendInvocationRequest(scriptInvocationContext); + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg))); } @@ -368,19 +386,11 @@ public async Task SendInvocationRequest_SignalCancellation_WithCapability_SendsI var invocationId = Guid.NewGuid(); var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}"; - IDictionary capabilities = new Dictionary() - { - { RpcWorkerConstants.HandlesInvocationCancelMessage, "1" } - }; - var cts = new CancellationTokenSource(); cts.CancelAfter(cancellationWaitTimeMs); var token = cts.Token; - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); - _testFunctionRpcService.PublishStartStreamEvent(_workerId); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); - await initTask; + await CreateDefaultWorkerChannel(capabilities: new Dictionary() { { RpcWorkerConstants.HandlesInvocationCancelMessage, "true" } }); var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null, token); await _workerChannel.SendInvocationRequest(scriptInvocationContext); @@ -392,6 +402,7 @@ public async Task SendInvocationRequest_SignalCancellation_WithCapability_SendsI break; } } + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog))); @@ -408,10 +419,7 @@ public async Task SendInvocationRequest_SignalCancellation_WithoutCapability_NoA cts.CancelAfter(cancellationWaitTimeMs); var token = cts.Token; - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); - _testFunctionRpcService.PublishStartStreamEvent(_workerId); - _testFunctionRpcService.PublishWorkerInitResponseEvent(); - await initTask; + await CreateDefaultWorkerChannel(); var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null, token); await _workerChannel.SendInvocationRequest(scriptInvocationContext); @@ -435,20 +443,11 @@ public async Task SendInvocationRequest_CancellationAlreadyRequested_ResultSourc var invocationId = Guid.NewGuid(); var expectedCancellationLog = "Cancellation has been requested, cancelling invocation request"; - IDictionary capabilities = new Dictionary() - { - { RpcWorkerConstants.HandlesInvocationCancelMessage, "1" } - }; - var cts = new CancellationTokenSource(); cts.CancelAfter(cancellationWaitTimeMs); var token = cts.Token; - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); - _testFunctionRpcService.PublishStartStreamEvent(_workerId); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); - await initTask; - + await CreateDefaultWorkerChannel(capabilities: new Dictionary() { { RpcWorkerConstants.HandlesInvocationCancelMessage, "true" } }); while (!token.IsCancellationRequested) { await Task.Delay(1000); @@ -473,19 +472,10 @@ public async Task SendInvocationCancelRequest_PublishesOutboundEvent() var invocationId = Guid.NewGuid(); var expectedCancellationLog = $"Sending invocation cancel request for InvocationId {invocationId.ToString()}"; - IDictionary capabilities = new Dictionary() - { - { RpcWorkerConstants.HandlesInvocationCancelMessage, "1" } - }; - - var initTask = _workerChannel.StartWorkerProcessAsync(CancellationToken.None); - _testFunctionRpcService.PublishStartStreamEvent(_workerId); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); - await initTask; - + await CreateDefaultWorkerChannel(capabilities: new Dictionary() { { RpcWorkerConstants.HandlesInvocationCancelMessage, "true" } }); var scriptInvocationContext = GetTestScriptInvocationContext(invocationId, null); _workerChannel.SendInvocationCancel(invocationId.ToString()); - + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, expectedCancellationLog))); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, _expectedLogMsg))); @@ -496,6 +486,7 @@ public async Task SendInvocationCancelRequest_PublishesOutboundEvent() [Fact] public async Task Drain_Verify() { + // note: uses custom worker channel var resultSource = new TaskCompletionSource(); Guid invocationId = Guid.NewGuid(); GrpcWorkerChannel channel = new GrpcWorkerChannel( @@ -531,6 +522,7 @@ await channel.InvokeResponse(new InvocationResponse [Fact] public async Task InFlight_Functions_FailedWithException() { + await CreateDefaultWorkerChannel(); var resultSource = new TaskCompletionSource(); ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), resultSource); await _workerChannel.SendInvocationRequest(scriptInvocationContext); @@ -543,20 +535,24 @@ public async Task InFlight_Functions_FailedWithException() } [Fact] - public void SendLoadRequests_PublishesOutboundEvents() + public async Task SendLoadRequests_PublishesOutboundEvents() { + await CreateDefaultWorkerChannel(); _metricsLogger.ClearCollections(); _workerChannel.SetupFunctionInvocationBuffers(GetTestFunctionsList("node")); _workerChannel.SendFunctionLoadRequests(null, TimeSpan.FromMinutes(5)); + await Task.Delay(500); var traces = _logger.GetLogMessages(); var functionLoadLogs = traces.Where(m => string.Equals(m.FormattedMessage, _expectedLogMsg)); AreExpectedMetricsGenerated(); - Assert.True(functionLoadLogs.Count() == 2); + Assert.Equal(3, functionLoadLogs.Count()); // one WorkInitRequest, two FunctionLoadRequest } [Fact] - public void SendLoadRequestCollection_PublishesOutboundEvents() + public async Task SendLoadRequestCollection_PublishesOutboundEvents() { + await CreateDefaultWorkerChannel(capabilities: new Dictionary() { { RpcWorkerConstants.SupportsLoadResponseCollection, "true" } }); + StartStream startStream = new StartStream() { WorkerId = _workerId @@ -568,20 +564,24 @@ public void SendLoadRequestCollection_PublishesOutboundEvents() GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); _workerChannel.SendWorkerInitRequest(rpcEvent); _testFunctionRpcService.PublishWorkerInitResponseEvent(new Dictionary() { { RpcWorkerConstants.SupportsLoadResponseCollection, "true" } }); + _metricsLogger.ClearCollections(); IEnumerable functionMetadata = GetTestFunctionsList("node"); _workerChannel.SetupFunctionInvocationBuffers(functionMetadata); _workerChannel.SendFunctionLoadRequests(null, TimeSpan.FromMinutes(5)); + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); var functionLoadLogs = traces.Where(m => string.Equals(m.FormattedMessage, _expectedLogMsg)); AreExpectedMetricsGenerated(); - Assert.True(functionLoadLogs.Count() == 2); + Assert.Equal(3, functionLoadLogs.Count()); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, string.Format("Sending FunctionLoadRequestCollection with number of functions:'{0}'", functionMetadata.ToList().Count)))); } [Fact] - public void SendLoadRequests_PublishesOutboundEvents_OrdersDisabled() + public async Task SendLoadRequests_PublishesOutboundEvents_OrdersDisabled() { + await CreateDefaultWorkerChannel(); var funcName = "ADisabledFunc"; var functions = GetTestFunctionsList_WithDisabled("node", funcName); @@ -590,44 +590,64 @@ public void SendLoadRequests_PublishesOutboundEvents_OrdersDisabled() _workerChannel.SetupFunctionInvocationBuffers(functions); _workerChannel.SendFunctionLoadRequests(null, TimeSpan.FromMinutes(5)); + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); var functionLoadLogs = traces.Where(m => m.FormattedMessage?.Contains(_expectedLoadMsgPartial) ?? false); var t = functionLoadLogs.Last().FormattedMessage; // Make sure that disabled func shows up last Assert.True(functionLoadLogs.Last().FormattedMessage.Contains(funcName)); Assert.False(functionLoadLogs.First().FormattedMessage.Contains(funcName)); - Assert.True(functionLoadLogs.Count() == 3); + Assert.Equal(3, functionLoadLogs.Count()); } [Fact] - public void SendLoadRequests_DoesNotTimeout_FunctionTimeoutNotSet() + public async Task SendLoadRequests_DoesNotTimeout_FunctionTimeoutNotSet() { + await CreateDefaultWorkerChannel(); var funcName = "ADisabledFunc"; var functions = GetTestFunctionsList_WithDisabled("node", funcName); _workerChannel.SetupFunctionInvocationBuffers(functions); _workerChannel.SendFunctionLoadRequests(null, null); + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); var errorLogs = traces.Where(m => m.Level == LogLevel.Error); Assert.Empty(errorLogs); } [Fact] - public void SendSendFunctionEnvironmentReloadRequest_PublishesOutboundEvents() + public async Task SendSendFunctionEnvironmentReloadRequest_PublishesOutboundEvents() { - Environment.SetEnvironmentVariable("TestNull", null); - Environment.SetEnvironmentVariable("TestEmpty", string.Empty); - Environment.SetEnvironmentVariable("TestValid", "TestValue"); - _workerChannel.SendFunctionEnvironmentReloadRequest(); - _testFunctionRpcService.PublishFunctionEnvironmentReloadResponseEvent(); + await CreateDefaultWorkerChannel(); + try + { + Environment.SetEnvironmentVariable("TestNull", null); + Environment.SetEnvironmentVariable("TestEmpty", string.Empty); + Environment.SetEnvironmentVariable("TestValid", "TestValue"); + _testFunctionRpcService.AutoReply(StreamingMessage.ContentOneofCase.FunctionEnvironmentReloadRequest); + var pending = _workerChannel.SendFunctionEnvironmentReloadRequest(); + await Task.Delay(500); + await pending; // this can timeout + } + catch + { + // show what we got even if we fail + var tmp = _logger.GetLogMessages(); + ShowOutput(tmp); + throw; + } var traces = _logger.GetLogMessages(); + ShowOutput(traces); var functionLoadLogs = traces.Where(m => string.Equals(m.FormattedMessage, "Sending FunctionEnvironmentReloadRequest to WorkerProcess with Pid: '910'")); - Assert.True(functionLoadLogs.Count() == 1); + Assert.Equal(1, functionLoadLogs.Count()); } [Fact] public async Task SendSendFunctionEnvironmentReloadRequest_ThrowsTimeout() { + await CreateDefaultWorkerChannel(); var reloadTask = _workerChannel.SendFunctionEnvironmentReloadRequest(); await Assert.ThrowsAsync(async () => await reloadTask); } @@ -635,6 +655,7 @@ public async Task SendSendFunctionEnvironmentReloadRequest_ThrowsTimeout() [Fact] public void SendFunctionEnvironmentReloadRequest_SanitizedEnvironmentVariables() { + CreateDefaultWorkerChannel(); var environmentVariables = new Dictionary() { { "TestNull", null }, @@ -654,6 +675,7 @@ public void SendFunctionEnvironmentReloadRequest_SanitizedEnvironmentVariables() [Fact] public void SendFunctionEnvironmentReloadRequest_WithDirectory() { + CreateDefaultWorkerChannel(); var environmentVariables = new Dictionary() { { "TestValid", "TestValue" } @@ -665,33 +687,42 @@ public void SendFunctionEnvironmentReloadRequest_WithDirectory() } [Fact] - public void ReceivesInboundEvent_InvocationResponse() + public async Task ReceivesInboundEvent_InvocationResponse() { + await CreateDefaultWorkerChannel(); _testFunctionRpcService.PublishInvocationResponseEvent(); + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "InvocationResponse received for invocation id: 'TestInvocationId'"))); } [Fact] - public void ReceivesInboundEvent_FunctionLoadResponse() + public async Task ReceivesInboundEvent_FunctionLoadResponse() { + await CreateDefaultWorkerChannel(); var functionMetadatas = GetTestFunctionsList("node"); _workerChannel.SetupFunctionInvocationBuffers(functionMetadatas); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionLoadRequest, + _ => _testFunctionRpcService.PublishFunctionLoadResponseEvent("TestFunctionId1")); _workerChannel.SendFunctionLoadRequests(null, TimeSpan.FromMinutes(5)); - _testFunctionRpcService.PublishFunctionLoadResponseEvent("TestFunctionId1"); + + await Task.Delay(500); var traces = _logger.GetLogMessages(); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js1' with functionId: 'TestFunctionId1'"))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js2' with functionId: 'TestFunctionId2'"))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Received FunctionLoadResponse for function: 'js1' with functionId: 'TestFunctionId1'."))); + ShowOutput(traces); + + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js1' with functionId: 'TestFunctionId1'")), "FunctionInvocationBuffer TestFunctionId1"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js2' with functionId: 'TestFunctionId2'")), "FunctionInvocationBuffer TestFunctionId2"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Received FunctionLoadResponse for function: 'js1' with functionId: 'TestFunctionId1'.")), "FunctionLoadResponse TestFunctionId1"); } [Fact] - public void ReceivesInboundEvent_Failed_FunctionLoadResponses() + public async Task ReceivesInboundEvent_Failed_FunctionLoadResponses() { IDictionary capabilities = new Dictionary() { { RpcWorkerConstants.SupportsLoadResponseCollection, "1" } }; + await CreateDefaultWorkerChannel(capabilities: capabilities); StartStream startStream = new StartStream() { @@ -705,7 +736,6 @@ public void ReceivesInboundEvent_Failed_FunctionLoadResponses() GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); var functionMetadatas = GetTestFunctionsList("node"); _workerChannel.SetupFunctionInvocationBuffers(functionMetadatas); @@ -713,20 +743,23 @@ public void ReceivesInboundEvent_Failed_FunctionLoadResponses() _testFunctionRpcService.PublishFunctionLoadResponsesEvent( new List() { "TestFunctionId1", "TestFunctionId2" }, new StatusResult() { Status = StatusResult.Types.Status.Failure }); + + await Task.Delay(500); var traces = _logger.GetLogMessages(); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js1' with functionId: 'TestFunctionId1'"))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js2' with functionId: 'TestFunctionId2'"))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Worker failed to load function: 'js1' with function id: 'TestFunctionId1'."))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Worker failed to load function: 'js2' with function id: 'TestFunctionId2'."))); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js1' with functionId: 'TestFunctionId1'")), "setup TestFunctionId1"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js2' with functionId: 'TestFunctionId2'")), "setup TestFunctionId2"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Worker failed to load function: 'js1' with function id: 'TestFunctionId1'.")), "fail TestFunctionId1"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Worker failed to load function: 'js2' with function id: 'TestFunctionId2'.")), "fail TestFunctionId2"); } [Fact] - public void ReceivesInboundEvent_FunctionLoadResponses() + public async Task ReceivesInboundEvent_FunctionLoadResponses() { IDictionary capabilities = new Dictionary() { { RpcWorkerConstants.SupportsLoadResponseCollection, "1" } }; + await CreateDefaultWorkerChannel(capabilities: capabilities); StartStream startStream = new StartStream() { @@ -740,7 +773,6 @@ public void ReceivesInboundEvent_FunctionLoadResponses() GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); var functionMetadatas = GetTestFunctionsList("node"); _workerChannel.SetupFunctionInvocationBuffers(functionMetadatas); @@ -748,85 +780,117 @@ public void ReceivesInboundEvent_FunctionLoadResponses() _testFunctionRpcService.PublishFunctionLoadResponsesEvent( new List() { "TestFunctionId1", "TestFunctionId2" }, new StatusResult() { Status = StatusResult.Types.Status.Success }); + + await Task.Delay(500); var traces = _logger.GetLogMessages(); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js1' with functionId: 'TestFunctionId1'"))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js2' with functionId: 'TestFunctionId2'"))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, string.Format("Received FunctionLoadResponseCollection with number of functions: '{0}'.", functionMetadatas.ToList().Count)))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Received FunctionLoadResponse for function: 'js1' with functionId: 'TestFunctionId1'."))); - Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Received FunctionLoadResponse for function: 'js2' with functionId: 'TestFunctionId2'."))); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js1' with functionId: 'TestFunctionId1'")), "setup TestFunctionId1"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Setting up FunctionInvocationBuffer for function: 'js2' with functionId: 'TestFunctionId2'")), "setup TestFunctionId2"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, string.Format("Received FunctionLoadResponseCollection with number of functions: '{0}'.", functionMetadatas.ToList().Count))), "recv FunctionLoadResponseCollection"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Received FunctionLoadResponse for function: 'js1' with functionId: 'TestFunctionId1'.")), "rev TestFunctionId1"); + Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, "Received FunctionLoadResponse for function: 'js2' with functionId: 'TestFunctionId2'.")), "rev TestFunctionId2"); } [Fact] - public void ReceivesInboundEvent_Successful_FunctionMetadataResponse() + public async Task ReceivesInboundEvent_Successful_FunctionMetadataResponse() { + await CreateDefaultWorkerChannel(); var functionMetadata = GetTestFunctionsList("python"); - var functions = _workerChannel.GetFunctionMetadata(); var functionId = "id123"; - _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", functionId, functionMetadata, true); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionsMetadataRequest, + _ => _testFunctionRpcService.PublishWorkerMetadataResponse(_workerId, functionId, functionMetadata, true)); + var functions = _workerChannel.GetFunctionMetadata(); + + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, $"Received the worker function metadata response from worker {_workerChannel.Id}"))); } [Fact] - public void ReceivesInboundEvent_Successful_FunctionMetadataResponse_UseDefaultMetadataIndexing_True() + public async Task ReceivesInboundEvent_Successful_FunctionMetadataResponse_UseDefaultMetadataIndexing_True() { + await CreateDefaultWorkerChannel(); var functionMetadata = GetTestFunctionsList("python"); - var functions = _workerChannel.GetFunctionMetadata(); var functionId = "id123"; - _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", functionId, functionMetadata, true, useDefaultMetadataIndexing: true); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionsMetadataRequest, + _ => _testFunctionRpcService.PublishWorkerMetadataResponse(_workerId, functionId, functionMetadata, true, useDefaultMetadataIndexing: true)); + var functions = _workerChannel.GetFunctionMetadata(); + + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, $"Received the worker function metadata response from worker {_workerChannel.Id}"))); } [Fact] - public void ReceivesInboundEvent_Successful_FunctionMetadataResponse_UseDefaultMetadataIndexing_False() + public async Task ReceivesInboundEvent_Successful_FunctionMetadataResponse_UseDefaultMetadataIndexing_False() { + await CreateDefaultWorkerChannel(); var functionMetadata = GetTestFunctionsList("python"); - var functions = _workerChannel.GetFunctionMetadata(); var functionId = "id123"; - _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", functionId, functionMetadata, true, useDefaultMetadataIndexing: false); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionsMetadataRequest, + _ => _testFunctionRpcService.PublishWorkerMetadataResponse(_workerId, functionId, functionMetadata, true, useDefaultMetadataIndexing: false)); + var functions = _workerChannel.GetFunctionMetadata(); + + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, $"Received the worker function metadata response from worker {_workerChannel.Id}"))); } [Fact] - public void ReceivesInboundEvent_Failed_UseDefaultMetadataIndexing_True_HostIndexing() + public async Task ReceivesInboundEvent_Failed_UseDefaultMetadataIndexing_True_HostIndexing() { + await CreateDefaultWorkerChannel(); var functionMetadata = GetTestFunctionsList("python"); - var functions = _workerChannel.GetFunctionMetadata(); var functionId = "id123"; - _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", functionId, functionMetadata, false, useDefaultMetadataIndexing: true); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionsMetadataRequest, + _ => _testFunctionRpcService.PublishWorkerMetadataResponse(_workerId, functionId, functionMetadata, false, useDefaultMetadataIndexing: true)); + var functions = _workerChannel.GetFunctionMetadata(); + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, $"Received the worker function metadata response from worker {_workerChannel.Id}"))); } [Fact] - public void ReceivesInboundEvent_Failed_UseDefaultMetadataIndexing_False_WorkerIndexing() + public async Task ReceivesInboundEvent_Failed_UseDefaultMetadataIndexing_False_WorkerIndexing() { + await CreateDefaultWorkerChannel(); var functionMetadata = GetTestFunctionsList("python"); - var functions = _workerChannel.GetFunctionMetadata(); var functionId = "id123"; - _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", functionId, functionMetadata, false, useDefaultMetadataIndexing: false); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionsMetadataRequest, + _ => _testFunctionRpcService.PublishWorkerMetadataResponse(_workerId, functionId, functionMetadata, false, useDefaultMetadataIndexing: false)); + var functions = _workerChannel.GetFunctionMetadata(); + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, $"Worker failed to index function {functionId}"))); } [Fact] - public void ReceivesInboundEvent_Failed_FunctionMetadataResponse() + public async Task ReceivesInboundEvent_Failed_FunctionMetadataResponse() { + await CreateDefaultWorkerChannel(); + var functionId = "id123"; var functionMetadata = GetTestFunctionsList("python"); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionsMetadataRequest, + _ => _testFunctionRpcService.PublishWorkerMetadataResponse(_workerId, functionId, functionMetadata, false)); var functions = _workerChannel.GetFunctionMetadata(); - var functionId = "id123"; - _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", functionId, functionMetadata, false); + await Task.Delay(500); var traces = _logger.GetLogMessages(); + ShowOutput(traces); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, $"Worker failed to index function {functionId}"))); } [Fact] - public void ReceivesInboundEvent_Failed_OverallFunctionMetadataResponse() + public async Task ReceivesInboundEvent_Failed_OverallFunctionMetadataResponse() { + await CreateDefaultWorkerChannel(); + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.FunctionsMetadataRequest, + _ => _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", null, null, false, false, false)); var functions = _workerChannel.GetFunctionMetadata(); - _testFunctionRpcService.PublishWorkerMetadataResponse("TestFunctionId1", null, null, false, false, false); + await Task.Delay(500); var traces = _logger.GetLogMessages(); Assert.True(traces.Any(m => string.Equals(m.FormattedMessage, $"Worker failed to index functions"))); } @@ -834,6 +898,7 @@ public void ReceivesInboundEvent_Failed_OverallFunctionMetadataResponse() [Fact] public void FunctionLoadRequest_IsExpected() { + CreateDefaultWorkerChannel(); FunctionMetadata metadata = new FunctionMetadata() { Language = "node", @@ -850,11 +915,11 @@ public void FunctionLoadRequest_IsExpected() /// Verify that shared memory data transfer is enabled if all required settings are set. /// [Fact] - public void SharedMemoryDataTransferSetting_VerifyEnabled() + public async Task SharedMemoryDataTransferSetting_VerifyEnabled() { - EnableSharedMemoryDataTransfer(); - - Assert.True(_workerChannel.IsSharedMemoryDataTransferEnabled()); + await CreateSharedMemoryEnabledWorkerChannel(); + await Task.Delay(500); + Assert.True(_workerChannel.IsSharedMemoryDataTransferEnabled(), "shared memory should be enabled"); } /// @@ -863,6 +928,7 @@ public void SharedMemoryDataTransferSetting_VerifyEnabled() [Fact] public void SharedMemoryDataTransferSetting_VerifyDisabled() { + CreateDefaultWorkerChannel(); Assert.False(_workerChannel.IsSharedMemoryDataTransferEnabled()); } @@ -875,21 +941,7 @@ public void SharedMemoryDataTransferSetting_VerifyDisabledIfWorkerCapabilityAbse { // Enable shared memory data transfer in the environment _testEnvironment.SetEnvironmentVariable(RpcWorkerConstants.FunctionsWorkerSharedMemoryDataTransferEnabledSettingName, "1"); - - StartStream startStream = new StartStream() - { - WorkerId = _workerId - }; - - StreamingMessage startStreamMessage = new StreamingMessage() - { - StartStream = startStream - }; - - // Send worker init request and enable the capabilities - GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); - _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(); + CreateDefaultWorkerChannel(); Assert.False(_workerChannel.IsSharedMemoryDataTransferEnabled()); } @@ -901,33 +953,14 @@ public void SharedMemoryDataTransferSetting_VerifyDisabledIfWorkerCapabilityAbse [Fact] public void SharedMemoryDataTransferSetting_VerifyDisabledIfEnvironmentVariableAbsent() { - // Enable shared memory data transfer capability in the worker - IDictionary capabilities = new Dictionary() - { - { RpcWorkerConstants.SharedMemoryDataTransfer, "1" } - }; - - StartStream startStream = new StartStream() - { - WorkerId = _workerId - }; - - StreamingMessage startStreamMessage = new StreamingMessage() - { - StartStream = startStream - }; - - // Send worker init request and enable the capabilities - GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); - _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); - + CreateSharedMemoryEnabledWorkerChannel(setEnvironmentVariable: false); Assert.False(_workerChannel.IsSharedMemoryDataTransferEnabled()); } [Fact] public async Task GetLatencies_StartsTimer_WhenDynamicConcurrencyEnabled() { + // note: uses custom worker channel RpcWorkerConfig config = new RpcWorkerConfig() { Description = new RpcWorkerDescription() @@ -966,6 +999,7 @@ await TestHelpers.Await(() => [Fact] public async Task GetLatencies_DoesNot_StartTimer_WhenDynamicConcurrencyDisabled() { + // note: uses custom worker channels RpcWorkerConfig config = new RpcWorkerConfig() { Description = new RpcWorkerDescription() @@ -993,57 +1027,81 @@ public async Task GetLatencies_DoesNot_StartTimer_WhenDynamicConcurrencyDisabled IEnumerable latencyHistory = workerChannel.GetLatencies(); - Assert.True(latencyHistory.Count() == 0); + Assert.Equal(0, latencyHistory.Count()); } [Fact] public async Task SendInvocationRequest_ValidateTraceContext() { + await CreateDefaultWorkerChannel(); ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null); - await _workerChannelWithMockEventManager.SendInvocationRequest(scriptInvocationContext); + + await _workerChannel.SendInvocationRequest(scriptInvocationContext); + + RpcTraceContext ctx = null; + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.InvocationRequest, evt => + { + ctx = evt.Message.InvocationRequest.TraceContext; + }); + await Task.Delay(500); + + Assert.NotNull(ctx); + var attribs = ctx.Attributes; + Assert.NotNull(attribs); + if (_testEnvironment.IsApplicationInsightsAgentEnabled()) { - _eventManagerMock.Verify(proxy => proxy.Publish(It.Is( - grpcEvent => grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(ScriptConstants.LogPropertyProcessIdKey) - && grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(ScriptConstants.LogPropertyHostInstanceIdKey) - && grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(LogConstants.CategoryNameKey) - && grpcEvent.Message.InvocationRequest.TraceContext.Attributes[LogConstants.CategoryNameKey].Equals("testcat1") - && grpcEvent.Message.InvocationRequest.TraceContext.Attributes.Count == 3))); + _testOutput.WriteLine("Checking ENABLED app-insights fields..."); + Assert.True(attribs.ContainsKey(ScriptConstants.LogPropertyProcessIdKey), "ScriptConstants.LogPropertyProcessIdKey"); + Assert.True(attribs.ContainsKey(ScriptConstants.LogPropertyHostInstanceIdKey), "ScriptConstants.LogPropertyHostInstanceIdKey"); + Assert.True(attribs.TryGetValue(LogConstants.CategoryNameKey, out var catKey), "LogConstants.CategoryNameKey"); + Assert.Equal(catKey, "testcat1"); + Assert.Equal(3, attribs.Count); } else { - _eventManagerMock.Verify(proxy => proxy.Publish(It.Is( - grpcEvent => !grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(ScriptConstants.LogPropertyProcessIdKey) - && !grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(ScriptConstants.LogPropertyHostInstanceIdKey) - && !grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(LogConstants.CategoryNameKey)))); + _testOutput.WriteLine("Checking DISABLED app-insights fields..."); + Assert.False(attribs.ContainsKey(ScriptConstants.LogPropertyProcessIdKey), "ScriptConstants.LogPropertyProcessIdKey"); + Assert.False(attribs.ContainsKey(ScriptConstants.LogPropertyHostInstanceIdKey), "ScriptConstants.LogPropertyHostInstanceIdKey"); + Assert.False(attribs.ContainsKey(LogConstants.CategoryNameKey), "LogConstants.CategoryNameKey"); + Assert.Equal(0, attribs.Count); } } [Fact] public async Task SendInvocationRequest_ValidateTraceContext_SessionId() { + await CreateDefaultWorkerChannel(); string sessionId = "sessionId1234"; Activity activity = new Activity("testActivity"); activity.AddBaggage(ScriptConstants.LiveLogsSessionAIKey, sessionId); activity.Start(); ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), null); - await _workerChannelWithMockEventManager.SendInvocationRequest(scriptInvocationContext); + + OutboundGrpcEvent grpcEvent = null; + _testFunctionRpcService.OnMessage(StreamingMessage.ContentOneofCase.InvocationRequest, evt => + { + grpcEvent = evt; + }); + await _workerChannel.SendInvocationRequest(scriptInvocationContext); + await Task.Delay(500); + + Assert.NotNull(grpcEvent); + activity.Stop(); - _eventManagerMock.Verify(p => p.Publish(It.Is(grpcEvent => ValidateInvocationRequest(grpcEvent, sessionId)))); - } + var attribs = grpcEvent.Message.InvocationRequest.TraceContext.Attributes; - private bool ValidateInvocationRequest(OutboundGrpcEvent grpcEvent, string sessionId) - { if (_testEnvironment.IsApplicationInsightsAgentEnabled()) { - return grpcEvent.Message.InvocationRequest.TraceContext.Attributes[ScriptConstants.LiveLogsSessionAIKey].Equals(sessionId) - && grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(LogConstants.CategoryNameKey) - && grpcEvent.Message.InvocationRequest.TraceContext.Attributes[LogConstants.CategoryNameKey].Equals("testcat1") - && grpcEvent.Message.InvocationRequest.TraceContext.Attributes.Count == 4; + Assert.True(attribs.TryGetValue(ScriptConstants.LiveLogsSessionAIKey, out var aiKey), "ScriptConstants.LiveLogsSessionAIKey"); + Assert.Equal(sessionId, aiKey); + Assert.True(attribs.TryGetValue(LogConstants.CategoryNameKey, out var catKey), "LogConstants.CategoryNameKey"); + Assert.Equal("testcat1", catKey); + Assert.Equal(4, attribs.Count); } else { - return !grpcEvent.Message.InvocationRequest.TraceContext.Attributes.ContainsKey(LogConstants.CategoryNameKey); + Assert.False(attribs.ContainsKey(LogConstants.CategoryNameKey), "LogConstants.CategoryNameKey"); } } @@ -1153,31 +1211,21 @@ private bool AreExpectedMetricsGenerated() return _metricsLogger.EventsBegan.Contains(MetricEventNames.FunctionLoadRequestResponse); } - private void EnableSharedMemoryDataTransfer() + private Task CreateSharedMemoryEnabledWorkerChannel(bool setEnvironmentVariable = true) { - // Enable shared memory data transfer in the environment - _testEnvironment.SetEnvironmentVariable(RpcWorkerConstants.FunctionsWorkerSharedMemoryDataTransferEnabledSettingName, "1"); + if (setEnvironmentVariable) + { + // Enable shared memory data transfer in the environment + _testEnvironment.SetEnvironmentVariable(RpcWorkerConstants.FunctionsWorkerSharedMemoryDataTransferEnabledSettingName, "1"); + } // Enable shared memory data transfer capability in the worker IDictionary capabilities = new Dictionary() { { RpcWorkerConstants.SharedMemoryDataTransfer, "1" } }; - - StartStream startStream = new StartStream() - { - WorkerId = _workerId - }; - - StreamingMessage startStreamMessage = new StreamingMessage() - { - StartStream = startStream - }; - // Send worker init request and enable the capabilities - GrpcEvent rpcEvent = new GrpcEvent(_workerId, startStreamMessage); - _workerChannel.SendWorkerInitRequest(rpcEvent); - _testFunctionRpcService.PublishWorkerInitResponseEvent(capabilities); + return CreateDefaultWorkerChannel(capabilities: capabilities); } } } diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/RpcWorkerConfigFactoryTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/RpcWorkerConfigFactoryTests.cs index baea4e1087..517d6deac0 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/RpcWorkerConfigFactoryTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/RpcWorkerConfigFactoryTests.cs @@ -92,6 +92,12 @@ public void LanguageWorker_WorkersDir_NotSet() [Fact] public void JavaPath_FromEnvVars() { + var javaHome = Environment.GetEnvironmentVariable("JAVA_HOME"); + if (string.IsNullOrWhiteSpace(javaHome)) + { + // if the var doesn't exist, set something temporary to make it at least work + Environment.SetEnvironmentVariable("JAVA_HOME", Path.GetTempPath()); + } var configBuilder = ScriptSettingsManager.CreateDefaultConfigurationBuilder(); var config = configBuilder.Build(); var scriptSettingsManager = new ScriptSettingsManager(config); diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs b/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs index 45c79eb69e..080e7d7b8d 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/TestFunctionRpcService.cs @@ -2,46 +2,138 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using System.Runtime.InteropServices; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.Azure.WebJobs.Script.Description; using Microsoft.Azure.WebJobs.Script.Diagnostics; using Microsoft.Azure.WebJobs.Script.Eventing; -using Microsoft.Azure.WebJobs.Script.Grpc; using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; using Microsoft.Azure.WebJobs.Script.Grpc.Messages; -using Microsoft.Azure.WebJobs.Script.Workers; -using Microsoft.Azure.WebJobs.Script.Workers.Rpc; -using Microsoft.Azure.WebJobs.Script.Workers.SharedMemoryDataTransfer; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.Extensions.Options; -using Moq; -using Xunit; namespace Microsoft.Azure.WebJobs.Script.Tests.Workers.Rpc { public class TestFunctionRpcService { - private IScriptEventManager _eventManager; private ILogger _logger; private string _workerId; private IDictionary _outboundEventSubscriptions = new Dictionary(); + private ChannelWriter _inboundWriter; + private ConcurrentDictionary> _handlers = new ConcurrentDictionary>(); public TestFunctionRpcService(IScriptEventManager eventManager, string workerId, TestLogger logger, string expectedLogMsg = "") { - _eventManager = eventManager; _logger = logger; _workerId = workerId; - _outboundEventSubscriptions.Add(workerId, _eventManager.OfType() - .Where(evt => evt.WorkerId == workerId) - .Subscribe(evt => - { - _logger.LogInformation(expectedLogMsg); - })); + if (eventManager.TryGetGrpcChannels(workerId, out var inbound, out var outbound)) + { + _ = ListenAsync(outbound.Reader, expectedLogMsg); + _inboundWriter = inbound.Writer; + + PublishStartStreamEvent(); // simulate the start-stream immediately + } + } + + public void OnMessage(StreamingMessage.ContentOneofCase messageType, Action callback) + => _handlers.AddOrUpdate(messageType, callback, (messageType, oldValue) => oldValue + callback); + + public void AutoReply(StreamingMessage.ContentOneofCase messageType) + { + // apply standard default responses + Action callback = messageType switch + { + StreamingMessage.ContentOneofCase.FunctionEnvironmentReloadRequest => _ => PublishFunctionEnvironmentReloadResponseEvent(), + _ => null, + }; + if (callback is not null) + { + OnMessage(messageType, callback); + } + } + + private void OnMessage(OutboundGrpcEvent message) + { + if (_handlers.TryRemove(message.MessageType, out var action)) + { + try + { + _logger.LogDebug("[service] invoking auto-reply for {0}, {1}: {2}", _workerId, message.MessageType, action?.Method?.Name); + action?.Invoke(message); + } + catch (Exception ex) + { + _logger.LogError(ex.Message); + } + } + } + + private async Task ListenAsync(ChannelReader source, string expectedLogMsg) + { + await Task.Yield(); // free up caller + try + { + while (await source.WaitToReadAsync()) + { + while (source.TryRead(out var evt)) + { + _logger.LogDebug("[service] received {0}, {1}", evt.WorkerId, evt.MessageType); + _logger.LogInformation(expectedLogMsg); + + OnMessage(evt); + } + } + } + catch + { + } + } + + private ValueTask WriteAsync(StreamingMessage message) + => _inboundWriter is null ? default + : _inboundWriter.WriteAsync(new InboundGrpcEvent(_workerId, message)); + + private void Write(StreamingMessage message) + { + if (_inboundWriter is null) + { + _logger.LogDebug("[service] no writer for {0}, {1}", _workerId, message.ContentCase); + return; + } + var evt = new InboundGrpcEvent(_workerId, message); + _logger.LogDebug("[service] sending {0}, {1}", evt.WorkerId, evt.MessageType); + if (_inboundWriter.TryWrite(evt)) + { + return; + } + var vt = _inboundWriter.WriteAsync(evt); + if (vt.IsCompleted) + { + try + { + vt.GetAwaiter().GetResult(); + } + catch (Exception ex) + { + _logger.LogError(ex.Message); + } + } + else + { + _ = ObserveEventually(vt, _logger); + } + static async Task ObserveEventually(ValueTask valueTask, ILogger logger) + { + try + { + await valueTask; + } + catch (Exception ex) + { + logger.LogError(ex.Message); + } + } } public void PublishFunctionLoadResponseEvent(string functionId) @@ -59,7 +151,7 @@ public void PublishFunctionLoadResponseEvent(string functionId) { FunctionLoadResponse = functionLoadResponse }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } public void PublishFunctionLoadResponsesEvent(List functionIds, StatusResult statusResult) @@ -81,7 +173,7 @@ public void PublishFunctionLoadResponsesEvent(List functionIds, StatusRe { FunctionLoadResponseCollection = functionLoadResponseCollection }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } public void PublishFunctionEnvironmentReloadResponseEvent() @@ -91,7 +183,7 @@ public void PublishFunctionEnvironmentReloadResponseEvent() { FunctionEnvironmentReloadResponse = relaodEnvResponse }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } public void PublishWorkerInitResponseEvent(IDictionary capabilities = null, WorkerMetadata workerMetadata = null) @@ -121,10 +213,10 @@ public void PublishWorkerInitResponseEvent(IDictionary capabilit WorkerInitResponse = initResponse }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } - public void PublishWorkerInitResponseEventWithSharedMemoryDataTransferCapability() + private void PublishWorkerInitResponseEventWithSharedMemoryDataTransferCapability() { StatusResult statusResult = new StatusResult() { @@ -138,7 +230,7 @@ public void PublishWorkerInitResponseEventWithSharedMemoryDataTransferCapability { WorkerInitResponse = initResponse }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } public void PublishSystemLogEvent(RpcLog.Types.Level inputLevel) @@ -154,7 +246,7 @@ public void PublishSystemLogEvent(RpcLog.Types.Level inputLevel) { RpcLog = rpcLog }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, logMessage)); + Write(logMessage); } public static FunctionEnvironmentReloadResponse GetTestFunctionEnvReloadResponse() @@ -185,10 +277,10 @@ public void PublishInvocationResponseEvent() { InvocationResponse = invocationResponse }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } - public void PublishStartStreamEvent(string workerId) + private void PublishStartStreamEvent() { StatusResult statusResult = new StatusResult() { @@ -196,13 +288,13 @@ public void PublishStartStreamEvent(string workerId) }; StartStream startStream = new StartStream() { - WorkerId = workerId + WorkerId = _workerId }; StreamingMessage responseMessage = new StreamingMessage() { StartStream = startStream }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } public void PublishWorkerMetadataResponse(string workerId, string functionId, IEnumerable functionMetadata, bool successful, bool useDefaultMetadataIndexing = false, bool overallStatus = true) @@ -245,7 +337,7 @@ public void PublishWorkerMetadataResponse(string workerId, string functionId, IE { FunctionMetadataResponse = overallResponse }; - _eventManager.Publish(new InboundGrpcEvent(_workerId, responseMessage)); + Write(responseMessage); } } } \ No newline at end of file