diff --git a/src/DtronixMessageQueue.Tests/DtronixMessageQueue.Tests.csproj b/src/DtronixMessageQueue.Tests/DtronixMessageQueue.Tests.csproj index 18b0963..9f100b0 100644 --- a/src/DtronixMessageQueue.Tests/DtronixMessageQueue.Tests.csproj +++ b/src/DtronixMessageQueue.Tests/DtronixMessageQueue.Tests.csproj @@ -80,7 +80,6 @@ - diff --git a/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs b/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs index 5c576b6..a8f0fa5 100644 --- a/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs +++ b/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs @@ -1,5 +1,7 @@ using System; using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; using DtronixMessageQueue.Tests.Rpc.Services.Server; using Xunit; using Xunit.Abstractions; @@ -11,6 +13,7 @@ public class RpcClientTests : RpcTestsBase { public RpcClientTests(ITestOutputHelper output) : base(output) { } + public class Test { public string TestStr { get; set; } public int Length { get; set; } @@ -57,7 +60,7 @@ public void Client_calls_proxy_method_sequential() { for (int i = 0; i < 10; i++) { added_int = service.Add(added_int, 1); } - + Output.WriteLine($"{stopwatch.ElapsedMilliseconds}"); TestStatus.Set(); }; @@ -69,25 +72,38 @@ public void Client_calls_proxy_method_sequential() { public void Client_calls_proxy_method_and_canceles() { Server.Connected += (sender, args) => { - args.Session.AddService(new CalculatorService()); + var service = new CalculatorService(); + args.Session.AddService(service); + + service.LongRunningTaskCanceled += (o, event_args) => { + TestStatus.Set(); + }; }; Client.Connected += (sender, args) => { args.Session.AddProxy(new CalculatorService()); var service = Client.Session.GetProxy(); - Stopwatch stopwatch = Stopwatch.StartNew(); - - int added_int = 0; - for (int i = 0; i < 10; i++) { - added_int = service.Add(added_int, 1); + var token_source = new CancellationTokenSource(); + + token_source.CancelAfter(200); + bool threw = false; + try { + service.LongRunningTask(1, 2, token_source.Token); + } catch (OperationCanceledException ex) { + threw = true; } - Output.WriteLine($"{stopwatch.ElapsedMilliseconds}"); - TestStatus.Set(); + if (threw != true) { + LastException = new Exception("Operation did not cancel."); + } + + }; StartAndWait(); } + + } } diff --git a/src/DtronixMessageQueue.Tests/Rpc/Services/Server/CalculatorService.cs b/src/DtronixMessageQueue.Tests/Rpc/Services/Server/CalculatorService.cs index 34924bc..4c58ce6 100644 --- a/src/DtronixMessageQueue.Tests/Rpc/Services/Server/CalculatorService.cs +++ b/src/DtronixMessageQueue.Tests/Rpc/Services/Server/CalculatorService.cs @@ -1,5 +1,6 @@ using System; using System.Threading; +using DtronixMessageQueue.Rpc; namespace DtronixMessageQueue.Tests.Rpc.Services.Server { public class CalculatorService : MarshalByRefObject, ICalculatorService { @@ -27,13 +28,22 @@ public int Divide(int number_1, int number_2) { public int LongRunningTask(int number_1, int number_2, CancellationToken token) { ManualResetEventSlim mre = new ManualResetEventSlim(); - mre.Wait(token); - - if (mre.IsSet == false) { + try { + mre.Wait(token); + } catch (Exception) { LongRunningTaskCanceled?.Invoke(this, EventArgs.Empty); + throw; } return number_1 / number_2; } } + + public interface ICalculatorService : IRemoteService { + int Add(int number_1, int number_2); + int Subtract(int number_1, int number_2); + int Multiply(int number_1, int number_2); + int Divide(int number_1, int number_2); + int LongRunningTask(int number_1, int number_2, CancellationToken token); + } } diff --git a/src/DtronixMessageQueue.Tests/Rpc/Services/Server/ICalculatorService.cs b/src/DtronixMessageQueue.Tests/Rpc/Services/Server/ICalculatorService.cs deleted file mode 100644 index 6a3749e..0000000 --- a/src/DtronixMessageQueue.Tests/Rpc/Services/Server/ICalculatorService.cs +++ /dev/null @@ -1,10 +0,0 @@ -using DtronixMessageQueue.Rpc; - -namespace DtronixMessageQueue.Tests.Rpc.Services.Server { - public interface ICalculatorService : IRemoteService { - int Add(int number_1, int number_2); - int Subtract(int number_1, int number_2); - int Multiply(int number_1, int number_2); - int Divide(int number_1, int number_2); - } -} diff --git a/src/DtronixMessageQueue/Rpc/RpcProxy.cs b/src/DtronixMessageQueue/Rpc/RpcProxy.cs index c5cf60b..b82fef4 100644 --- a/src/DtronixMessageQueue/Rpc/RpcProxy.cs +++ b/src/DtronixMessageQueue/Rpc/RpcProxy.cs @@ -19,7 +19,7 @@ class RpcProxy : RealProxy public RpcProxy(T decorated, RpcSession session) : base(typeof(T)) { this.decorated = decorated; - this.session = (TSession)session; + this.session = (TSession) session; } public override IMessage Invoke(IMessage msg) { @@ -48,22 +48,23 @@ public override IMessage Invoke(IMessage msg) { // Determine what kind of method we are calling. if (method_info.ReturnType == typeof(void)) { - store.MessageWriter.Write((byte)RpcMessageType.RpcCallNoReturn); + store.MessageWriter.Write((byte) RpcMessageType.RpcCallNoReturn); } else { - store.MessageWriter.Write((byte)RpcMessageType.RpcCall); + store.MessageWriter.Write((byte) RpcMessageType.RpcCall); - return_wait = session.CreateReturnCallWait(); + return_wait = session.CreateWaitOperation(); store.MessageWriter.Write(return_wait.Id); return_wait.Token = cancellation_token; } store.MessageWriter.Write(decorated.Name); store.MessageWriter.Write(method_call.MethodName); - store.MessageWriter.Write((byte)arguments.Length); + store.MessageWriter.Write((byte) arguments.Length); int field_number = 0; foreach (var arg in arguments) { - RuntimeTypeModel.Default.SerializeWithLengthPrefix(store.Stream, arg, arg.GetType(), PrefixStyle.Base128, field_number++); + RuntimeTypeModel.Default.SerializeWithLengthPrefix(store.Stream, arg, arg.GetType(), PrefixStyle.Base128, + field_number++); store.MessageWriter.Write(store.Stream.ToArray()); // Should always read the entire buffer in one go. @@ -78,21 +79,19 @@ public override IMessage Invoke(IMessage msg) { return new ReturnMessage(null, null, 0, method_call.LogicalCallContext, method_call); } - return_wait.ReturnResetEvent.Wait(session.Config.SendTimeout, return_wait.Token); + try { + return_wait.ReturnResetEvent.Wait(return_wait.Token); + } catch (OperationCanceledException) { + session.CancelWaitOperation(return_wait.Id); + // If the operation was canceled, cancel the wait on this end and notify the other end. + throw new OperationCanceledException("Wait handle was canceled while waiting for a response."); + } + if (return_wait.ReturnResetEvent.IsSet == false) { throw new TimeoutException("Wait handle timed out waiting for a response."); } - if (return_wait.Token.IsCancellationRequested) { - store.MessageWriter.Clear(); - store.MessageWriter.Write((byte)RpcMessageType.RpcCallCancellation); - store.MessageWriter.Write(return_wait.Id); - - session.Send(store.MessageWriter.ToMessage()); - - throw new OperationCanceledException("Wait handle was canceled while waiting for a response."); - } try { diff --git a/src/DtronixMessageQueue/Rpc/RpcSession.cs b/src/DtronixMessageQueue/Rpc/RpcSession.cs index 1af454c..58c9e45 100644 --- a/src/DtronixMessageQueue/Rpc/RpcSession.cs +++ b/src/DtronixMessageQueue/Rpc/RpcSession.cs @@ -80,6 +80,11 @@ public override void OnIncomingMessage(object sender, IncomingMessageEventArgs 0) { @@ -179,15 +184,16 @@ private void ProcessRpcCall(MqMessage message, RpcMessageType message_type) { } } - - + if (cancellation_token_param > 0) { + parameters[parameters.Length - 1] = cancellation_source.Token; + } object return_value; try { return_value = method_info.Invoke(service, parameters); } catch (Exception ex) { - if (rec_message_return_id != 0) { + if (rec_message_return_id != 0 && ex.InnerException?.GetType() != typeof(OperationCanceledException)) { SendRpcException(store, ex, rec_message_return_id); } return; @@ -243,7 +249,7 @@ private void SendRpcException(SerializationStore.Store store, Exception ex, usho } - public RpcOperationWait CreateReturnCallWait() { + public RpcOperationWait CreateWaitOperation() { var return_wait = new RpcOperationWait { ReturnResetEvent = new ManualResetEventSlim() }; @@ -262,6 +268,17 @@ public RpcOperationWait CreateReturnCallWait() { return return_wait; } + public void CancelWaitOperation(ushort id) { + RpcOperationWait call_wait; + outstanding_waits.TryRemove(id, out call_wait); + + var frame = new MqFrame(new byte[3], MqFrameType.Last, (MqSocketConfig) Config); + frame.Write(0, (byte)RpcMessageType.RpcCallCancellation); + frame.Write(1, id); + + Send(frame); + } + private void ProcessRpcReturn(MqMessage mq_message) { var store = Store.Get();