diff --git a/Test/DurableTask.Core.Tests/DispatcherMiddlewareTests.cs b/Test/DurableTask.Core.Tests/DispatcherMiddlewareTests.cs new file mode 100644 index 000000000..c1b115a5a --- /dev/null +++ b/Test/DurableTask.Core.Tests/DispatcherMiddlewareTests.cs @@ -0,0 +1,165 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Tests +{ + using System; + using System.Diagnostics; + using System.Text; + using System.Threading.Tasks; + using DurableTask.Core.History; + using DurableTask.Emulator; + using DurableTask.Test.Orchestrations; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class DispatcherMiddlewareTests + { + TaskHubWorker worker; + TaskHubClient client; + + [TestInitialize] + public async Task Initialize() + { + var service = new LocalOrchestrationService(); + this.worker = new TaskHubWorker(service); + + await this.worker + .AddTaskOrchestrations(typeof(SimplestGreetingsOrchestration)) + .AddTaskActivities(typeof(SimplestGetUserTask), typeof(SimplestSendGreetingTask)) + .StartAsync(); + + this.client = new TaskHubClient(service); + } + + [TestCleanup] + public async Task TestCleanup() + { + await this.worker.StopAsync(true); + } + + [TestMethod] + public async Task DispatchMiddlewareContextBuiltInProperties() + { + TaskOrchestration orchestration = null; + OrchestrationRuntimeState state = null; + OrchestrationInstance instance1 = null; + + TaskActivity activity = null; + TaskScheduledEvent taskScheduledEvent = null; + OrchestrationInstance instance2 = null; + + this.worker.AddOrchestrationDispatcherMiddleware((context, next) => + { + orchestration = context.GetProperty(); + state = context.GetProperty(); + instance1 = context.GetProperty(); + + return next(); + }); + + this.worker.AddActivityDispatcherMiddleware((context, next) => + { + activity = context.GetProperty(); + taskScheduledEvent = context.GetProperty(); + instance2 = context.GetProperty(); + + return next(); + }); + + var instance = await this.client.CreateOrchestrationInstanceAsync(typeof(SimplestGreetingsOrchestration), null); + + TimeSpan timeout = TimeSpan.FromSeconds(Debugger.IsAttached ? 1000 : 10); + await this.client.WaitForOrchestrationAsync(instance, timeout); + + Assert.IsNotNull(orchestration); + Assert.IsNotNull(state); + Assert.IsNotNull(instance1); + + Assert.IsNotNull(activity); + Assert.IsNotNull(taskScheduledEvent); + Assert.IsNotNull(instance2); + + Assert.AreNotSame(instance1, instance2); + Assert.AreEqual(instance1.InstanceId, instance2.InstanceId); + } + + [TestMethod] + public async Task OrchestrationDispatcherMiddlewareContextFlow() + { + StringBuilder output = null; + + for (int i = 0; i < 10; i++) + { + string value = i.ToString(); + this.worker.AddOrchestrationDispatcherMiddleware(async (context, next) => + { + output = context.GetProperty("output"); + if (output == null) + { + output = new StringBuilder(); + context.SetProperty("output", output); + } + + output.Append(value); + await next(); + output.Append(value); + }); + } + + var instance = await this.client.CreateOrchestrationInstanceAsync(typeof(SimplestGreetingsOrchestration), null); + + TimeSpan timeout = TimeSpan.FromSeconds(Debugger.IsAttached ? 1000 : 10); + await this.client.WaitForOrchestrationAsync(instance, timeout); + + // Each reply gets a new context, so the output should stay the same regardless of how + // many replays an orchestration goes through. + Assert.IsNotNull(output); + Assert.AreEqual("01234567899876543210", output.ToString()); + } + + [TestMethod] + public async Task ActivityDispatcherMiddlewareContextFlow() + { + StringBuilder output = null; + + for (int i = 0; i < 10; i++) + { + string value = i.ToString(); + this.worker.AddActivityDispatcherMiddleware(async (context, next) => + { + output = context.GetProperty("output"); + if (output == null) + { + output = new StringBuilder(); + context.SetProperty("output", output); + } + + output.Append(value); + await next(); + output.Append(value); + }); + } + + var instance = await this.client.CreateOrchestrationInstanceAsync(typeof(SimplestGreetingsOrchestration), null); + + TimeSpan timeout = TimeSpan.FromSeconds(Debugger.IsAttached ? 1000 : 10); + await this.client.WaitForOrchestrationAsync(instance, timeout); + + // Each actiivty gets a new context, so the output should stay the same regardless of how + // many activities an orchestration schedules (as long as there is at least one). + Assert.IsNotNull(output); + Assert.AreEqual("01234567899876543210", output.ToString()); + } + } +} diff --git a/src/DurableTask.Core/Middleware/DispatchMiddlewareContext.cs b/src/DurableTask.Core/Middleware/DispatchMiddlewareContext.cs new file mode 100644 index 000000000..2d7fa8407 --- /dev/null +++ b/src/DurableTask.Core/Middleware/DispatchMiddlewareContext.cs @@ -0,0 +1,75 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Middleware +{ + using System; + using System.Collections.Generic; + + /// + /// Context data that can be used to share data between middleware. + /// + public class DispatchMiddlewareContext + { + internal DispatchMiddlewareContext() + { + } + + /// + /// Sets a property value to the context using the full name of the type as the key. + /// + /// The type of the property. + /// The value of the property. + public void SetProperty(T value) + { + this.SetProperty(typeof(T).FullName, value); + } + + /// + /// Sets a named property value to the context. + /// + /// The type of the property. + /// The name of the property. + /// The value of the property. + public void SetProperty(string key, T value) + { + this.Properties[key] = value; + } + + /// + /// Gets a property value from the context using the full name of . + /// + /// The type of the property. + /// The value of the property or default(T) if the property is not defined. + public T GetProperty() + { + return this.GetProperty(typeof(T).FullName); + } + + /// + /// Gets a named property value from the context. + /// + /// + /// The name of the property value. + /// The value of the property or default(T) if the property is not defined. + public T GetProperty(string key) + { + return this.Properties.TryGetValue(key, out object value) ? (T)value : default(T); + } + + /// + /// Gets a key/value collection that can be used to share data between middleware. + /// + public IDictionary Properties { get; } = new Dictionary(StringComparer.Ordinal); + } +} diff --git a/src/DurableTask.Core/Middleware/DispatchMiddlewareDelegate.cs b/src/DurableTask.Core/Middleware/DispatchMiddlewareDelegate.cs new file mode 100644 index 000000000..d11e18771 --- /dev/null +++ b/src/DurableTask.Core/Middleware/DispatchMiddlewareDelegate.cs @@ -0,0 +1,24 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Middleware +{ + using System.Threading.Tasks; + + /// + /// A function that runs in the task execution middleware pipeline. + /// + /// The for the task execution. + /// A task that represents the completion of the durable task execution. + public delegate Task DispatchMiddlewareDelegate(DispatchMiddlewareContext context); +} diff --git a/src/DurableTask.Core/Middleware/DispatchMiddlewarePipeline.cs b/src/DurableTask.Core/Middleware/DispatchMiddlewarePipeline.cs new file mode 100644 index 000000000..c0468296b --- /dev/null +++ b/src/DurableTask.Core/Middleware/DispatchMiddlewarePipeline.cs @@ -0,0 +1,49 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Middleware +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Threading.Tasks; + + class DispatchMiddlewarePipeline + { + readonly IList> components = + new List>(); + + public Task RunAsync(DispatchMiddlewareContext context, DispatchMiddlewareDelegate handler) + { + // Build the delegate chain + foreach (var component in this.components.Reverse()) + { + handler = component(handler); + } + + return handler(context); + } + + public void Add(Func, Task> middleware) + { + this.components.Add(next => + { + return context => + { + Func simpleNext = () => next(context); + return middleware(context, simpleNext); + }; + }); + } + } +} diff --git a/src/DurableTask.Core/TaskActivityDispatcher.cs b/src/DurableTask.Core/TaskActivityDispatcher.cs index 68fba5b92..d908ba4ae 100644 --- a/src/DurableTask.Core/TaskActivityDispatcher.cs +++ b/src/DurableTask.Core/TaskActivityDispatcher.cs @@ -20,6 +20,7 @@ namespace DurableTask.Core using DurableTask.Core.Common; using DurableTask.Core.Exceptions; using DurableTask.Core.History; + using DurableTask.Core.Middleware; using DurableTask.Core.Tracing; /// @@ -30,23 +31,16 @@ public sealed class TaskActivityDispatcher readonly INameVersionObjectManager objectManager; readonly WorkItemDispatcher dispatcher; readonly IOrchestrationService orchestrationService; - + readonly DispatchMiddlewarePipeline dispatchPipeline; + internal TaskActivityDispatcher( IOrchestrationService orchestrationService, - INameVersionObjectManager objectManager) + INameVersionObjectManager objectManager, + DispatchMiddlewarePipeline dispatchPipeline) { - if (orchestrationService == null) - { - throw new ArgumentNullException(nameof(orchestrationService)); - } - - if (objectManager == null) - { - throw new ArgumentNullException(nameof(objectManager)); - } - - this.orchestrationService = orchestrationService; - this.objectManager = objectManager; + this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService)); + this.objectManager = objectManager ?? throw new ArgumentNullException(nameof(objectManager)); + this.dispatchPipeline = dispatchPipeline ?? throw new ArgumentNullException(nameof(dispatchPipeline)); this.dispatcher = new WorkItemDispatcher( "TaskActivityDispatcher", @@ -124,25 +118,33 @@ async Task OnProcessWorkItemAsync(TaskActivityWorkItem workItem) var context = new TaskContext(taskMessage.OrchestrationInstance); HistoryEvent eventToRespond = null; - try - { - string output = await taskActivity.RunAsync(context, scheduledEvent.Input); - eventToRespond = new TaskCompletedEvent(-1, scheduledEvent.EventId, output); - } - catch (TaskFailureException e) - { - TraceHelper.TraceExceptionInstance(TraceEventType.Error, taskMessage.OrchestrationInstance, e); - string details = IncludeDetails ? e.Details : null; - eventToRespond = new TaskFailedEvent(-1, scheduledEvent.EventId, e.Message, details); - } - catch (Exception e) when (!Utils.IsFatal(e)) + var dispatchContext = new DispatchMiddlewareContext(); + dispatchContext.SetProperty(taskMessage.OrchestrationInstance); + dispatchContext.SetProperty(taskActivity); + dispatchContext.SetProperty(scheduledEvent); + + await this.dispatchPipeline.RunAsync(dispatchContext, async _ => { - TraceHelper.TraceExceptionInstance(TraceEventType.Error, taskMessage.OrchestrationInstance, e); - string details = IncludeDetails - ? $"Unhandled exception while executing task: {e}\n\t{e.StackTrace}" - : null; - eventToRespond = new TaskFailedEvent(-1, scheduledEvent.EventId, e.Message, details); - } + try + { + string output = await taskActivity.RunAsync(context, scheduledEvent.Input); + eventToRespond = new TaskCompletedEvent(-1, scheduledEvent.EventId, output); + } + catch (TaskFailureException e) + { + TraceHelper.TraceExceptionInstance(TraceEventType.Error, taskMessage.OrchestrationInstance, e); + string details = IncludeDetails ? e.Details : null; + eventToRespond = new TaskFailedEvent(-1, scheduledEvent.EventId, e.Message, details); + } + catch (Exception e) when (!Utils.IsFatal(e)) + { + TraceHelper.TraceExceptionInstance(TraceEventType.Error, taskMessage.OrchestrationInstance, e); + string details = IncludeDetails + ? $"Unhandled exception while executing task: {e}\n\t{e.StackTrace}" + : null; + eventToRespond = new TaskFailedEvent(-1, scheduledEvent.EventId, e.Message, details); + } + }); var responseTaskMessage = new TaskMessage { diff --git a/src/DurableTask.Core/TaskHubWorker.cs b/src/DurableTask.Core/TaskHubWorker.cs index 3a5a5849a..212c25e83 100644 --- a/src/DurableTask.Core/TaskHubWorker.cs +++ b/src/DurableTask.Core/TaskHubWorker.cs @@ -17,7 +17,7 @@ namespace DurableTask.Core using System.Reflection; using System.Threading; using System.Threading.Tasks; - using DurableTask.Core.Settings; + using DurableTask.Core.Middleware; /// /// Allows users to load the TaskOrchestration and TaskActivity classes and start @@ -28,6 +28,9 @@ public sealed class TaskHubWorker : IDisposable readonly INameVersionObjectManager activityManager; readonly INameVersionObjectManager orchestrationManager; + readonly DispatchMiddlewarePipeline orchestrationDispatchPipeline = new DispatchMiddlewarePipeline(); + readonly DispatchMiddlewarePipeline activityDispatchPipeline = new DispatchMiddlewarePipeline(); + readonly SemaphoreSlim slimLock = new SemaphoreSlim(1, 1); /// @@ -78,6 +81,24 @@ public TaskHubWorker( /// public TaskActivityDispatcher TaskActivityDispatcher => activityDispatcher; + /// + /// Adds a middleware delegate to the orchestration dispatch pipeline. + /// + /// Delegate to invoke whenever a message is dispatched to an orchestration. + public void AddOrchestrationDispatcherMiddleware(Func, Task> middleware) + { + orchestrationDispatchPipeline.Add(middleware ?? throw new ArgumentNullException(nameof(middleware))); + } + + /// + /// Adds a middleware delegate to the activity dispatch pipeline. + /// + /// Delegate to invoke whenever a message is dispatched to an activity. + public void AddActivityDispatcherMiddleware(Func, Task> middleware) + { + activityDispatchPipeline.Add(middleware ?? throw new ArgumentNullException(nameof(middleware))); + } + /// /// Starts the TaskHubWorker so it begins processing orchestrations and activities /// @@ -92,8 +113,14 @@ public async Task StartAsync() throw new InvalidOperationException("Worker is already started"); } - orchestrationDispatcher = new TaskOrchestrationDispatcher(orchestrationService, orchestrationManager); - activityDispatcher = new TaskActivityDispatcher(orchestrationService, activityManager); + orchestrationDispatcher = new TaskOrchestrationDispatcher( + orchestrationService, + orchestrationManager, + orchestrationDispatchPipeline); + activityDispatcher = new TaskActivityDispatcher( + orchestrationService, + activityManager, + activityDispatchPipeline); await orchestrationService.StartAsync(); await orchestrationDispatcher.StartAsync(); diff --git a/src/DurableTask.Core/TaskOrchestrationDispatcher.cs b/src/DurableTask.Core/TaskOrchestrationDispatcher.cs index d3c3e36cc..096824976 100644 --- a/src/DurableTask.Core/TaskOrchestrationDispatcher.cs +++ b/src/DurableTask.Core/TaskOrchestrationDispatcher.cs @@ -23,6 +23,7 @@ namespace DurableTask.Core using DurableTask.Core.Common; using DurableTask.Core.Exceptions; using DurableTask.Core.History; + using DurableTask.Core.Middleware; using DurableTask.Core.Serializing; using DurableTask.Core.Tracing; @@ -34,25 +35,18 @@ public class TaskOrchestrationDispatcher readonly INameVersionObjectManager objectManager; readonly IOrchestrationService orchestrationService; readonly WorkItemDispatcher dispatcher; + readonly DispatchMiddlewarePipeline dispatchPipeline; static readonly DataConverter DataConverter = new JsonDataConverter(); internal TaskOrchestrationDispatcher( IOrchestrationService orchestrationService, - INameVersionObjectManager objectManager) + INameVersionObjectManager objectManager, + DispatchMiddlewarePipeline dispatchPipeline) { - if (orchestrationService == null) - { - throw new ArgumentNullException(nameof(orchestrationService)); - } - - if (objectManager == null) - { - throw new ArgumentNullException(nameof(objectManager)); - } - - this.objectManager = objectManager; + this.objectManager = objectManager ?? throw new ArgumentNullException(nameof(objectManager)); + this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService)); + this.dispatchPipeline = dispatchPipeline ?? throw new ArgumentNullException(nameof(dispatchPipeline)); - this.orchestrationService = orchestrationService; this.dispatcher = new WorkItemDispatcher( "TaskOrchestrationDispatcher", item => item == null ? string.Empty : item.InstanceId, @@ -140,7 +134,7 @@ protected async Task OnProcessWorkItemAsync(TaskOrchestrationWorkItem workItem) "Executing user orchestration: {0}", DataConverter.Serialize(runtimeState.GetOrchestrationRuntimeStateDump(), true)); - IList decisions = ExecuteOrchestration(runtimeState).ToList(); + IList decisions = (await ExecuteOrchestrationAsync(runtimeState)).ToList(); TraceHelper.TraceInstance(TraceEventType.Information, runtimeState.OrchestrationInstance, @@ -269,7 +263,7 @@ await this.orchestrationService.CompleteTaskOrchestrationWorkItemAsync( instanceState); } - internal virtual IEnumerable ExecuteOrchestration(OrchestrationRuntimeState runtimeState) + async Task> ExecuteOrchestrationAsync(OrchestrationRuntimeState runtimeState) { TaskOrchestration taskOrchestration = objectManager.GetObject(runtimeState.Name, runtimeState.Version); if (taskOrchestration == null) @@ -278,8 +272,20 @@ internal virtual IEnumerable ExecuteOrchestration(Orchestrat new TypeMissingException($"Orchestration not found: ({runtimeState.Name}, {runtimeState.Version})")); } - var taskOrchestrationExecutor = new TaskOrchestrationExecutor(runtimeState, taskOrchestration); - IEnumerable decisions = taskOrchestrationExecutor.Execute(); + var dispatchContext = new DispatchMiddlewareContext(); + dispatchContext.SetProperty(runtimeState.OrchestrationInstance); + dispatchContext.SetProperty(taskOrchestration); + dispatchContext.SetProperty(runtimeState); + + IEnumerable decisions = null; + await this.dispatchPipeline.RunAsync(dispatchContext, _ => + { + var taskOrchestrationExecutor = new TaskOrchestrationExecutor(runtimeState, taskOrchestration); + decisions = taskOrchestrationExecutor.Execute(); + + return Task.FromResult(0); + }); + return decisions; } diff --git a/test/DurableTask.Core.Tests/DurableTask.Core.Tests.csproj b/test/DurableTask.Core.Tests/DurableTask.Core.Tests.csproj index 8951ed702..25ffd6fda 100644 --- a/test/DurableTask.Core.Tests/DurableTask.Core.Tests.csproj +++ b/test/DurableTask.Core.Tests/DurableTask.Core.Tests.csproj @@ -14,7 +14,9 @@ + +