Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Honor user-provided Json IContractResolver while maintaining marshaling capabilities #783

Merged
merged 7 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/StreamJsonRpc/JsonMessageFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace StreamJsonRpc
using System.IO.Pipelines;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.Serialization;
using System.Text;
Expand Down Expand Up @@ -162,6 +163,16 @@ public class JsonMessageFormatter : IJsonRpcAsyncMessageTextFormatter, IJsonRpcF
/// </summary>
private JsonRpcMessage? deserializingMessage;

/// <summary>
/// Whether <see cref="EnforceFormatterIsInitialized"/> has been executed.
/// </summary>
private bool formatterInitializationChecked;

/// <summary>
/// Object used to lock when running <see cref="EnforceFormatterIsInitialized"/>.
/// </summary>
private object formatterInitializationLock = new();
matteo-prosperi marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Initializes a new instance of the <see cref="JsonMessageFormatter"/> class
/// that uses JsonProgress (without the preamble) for its text encoding.
Expand Down Expand Up @@ -390,6 +401,8 @@ public JsonRpcMessage Deserialize(JToken json)
{
Requires.NotNull(json, nameof(json));

this.EnforceFormatterIsInitialized();

try
{
switch (this.ProtocolVersion.Major)
Expand Down Expand Up @@ -432,6 +445,8 @@ public JsonRpcMessage Deserialize(JToken json)
/// <returns>The JSON of the message.</returns>
public JToken Serialize(JsonRpcMessage message)
{
this.EnforceFormatterIsInitialized();

try
{
this.observedTransmittedRequestWithStringId |= message is JsonRpcRequest request && request.RequestId.String is not null;
Expand Down Expand Up @@ -528,6 +543,25 @@ private static object[] PartiallyParsePositionalArguments(JArray args)
return jtokenArray;
}

private void EnforceFormatterIsInitialized()
{
if (!this.formatterInitializationChecked)
{
lock (this.formatterInitializationLock)
{
if (!this.formatterInitializationChecked)
{
this.formatterInitializationChecked = true;
matteo-prosperi marked this conversation as resolved.
Show resolved Hide resolved
IContractResolver? originalContractResolver = this.JsonSerializer.ContractResolver;
if (originalContractResolver is not MarshalContractResolver)
{
this.JsonSerializer.ContractResolver = new MarshalContractResolver(this, originalContractResolver);
}
}
}
}
}

private void VerifyProtocolCompliance(bool condition, JToken message, string? explanation = null)
{
if (!condition)
Expand Down Expand Up @@ -1667,12 +1701,19 @@ public override void WriteJson(JsonWriter writer, Exception? value, JsonSerializ
private class MarshalContractResolver : DefaultContractResolver
matteo-prosperi marked this conversation as resolved.
Show resolved Hide resolved
{
private readonly JsonMessageFormatter formatter;
private readonly IContractResolver? userProvidedContractResolver;

public MarshalContractResolver(JsonMessageFormatter formatter)
{
this.formatter = formatter;
}

public MarshalContractResolver(JsonMessageFormatter formatter, IContractResolver? userProvidedContractResolver)
{
this.formatter = formatter;
this.userProvidedContractResolver = userProvidedContractResolver;
}

public override JsonContract ResolveContract(Type type)
{
if (this.formatter.TryGetMarshaledJsonConverter(type, out RpcMarshalableConverter? converter))
Expand All @@ -1684,7 +1725,10 @@ public override JsonContract ResolveContract(Type type)
};
}

JsonContract? result = base.ResolveContract(type);
JsonContract? result = this.userProvidedContractResolver is not null ?
this.userProvidedContractResolver.ResolveContract(type) :
base.ResolveContract(type);

switch (result)
{
case JsonObjectContract objectContract:
Expand Down
108 changes: 108 additions & 0 deletions test/StreamJsonRpc.Tests/JsonContractResolverTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.VisualStudio.Threading;
using Nerdbank.Streams;
using Newtonsoft.Json.Serialization;
using StreamJsonRpc;
using Xunit;
using Xunit.Abstractions;

public partial class JsonContractResolverTest : TestBase
{
protected readonly Server server = new Server();
protected readonly JsonRpc serverRpc;
protected readonly JsonRpc clientRpc;
protected readonly IServer client;

private const string ExceptionMessage = "Some exception";

public JsonContractResolverTest(ITestOutputHelper logger)
: base(logger)
{
var pipes = FullDuplexStream.CreatePipePair();

this.client = JsonRpc.Attach<IServer>(new LengthHeaderMessageHandler(pipes.Item1, this.CreateFormatter()));
this.clientRpc = ((IJsonRpcClientProxy)this.client).JsonRpc;

this.serverRpc = new JsonRpc(new LengthHeaderMessageHandler(pipes.Item2, this.CreateFormatter()));
this.serverRpc.AddLocalRpcTarget(this.server);

this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose);
this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose);

this.serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger));
this.clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger));

this.serverRpc.StartListening();
}

protected interface IServer : IDisposable
{
Task PushCompleteAndReturn(IObserver<int> observer);
}

[Fact]
public async Task PushCompleteAndReturn()
{
var observer = new MockObserver<int>();
await Task.Run(() => this.client.PushCompleteAndReturn(observer)).WithCancellation(this.TimeoutToken);
await observer.Completion.WithCancellation(this.TimeoutToken);
ImmutableList<int> result = await observer.Completion;
}

protected IJsonRpcMessageFormatter CreateFormatter()
{
var formatter = new JsonMessageFormatter();
formatter.JsonSerializer.ContractResolver = new CamelCasePropertyNamesContractResolver();
matteo-prosperi marked this conversation as resolved.
Show resolved Hide resolved

return formatter;
}

protected class Server : IServer
{
public Task PushCompleteAndReturn(IObserver<int> observer)
{
for (int i = 1; i <= 3; i++)
{
observer.OnNext(i);
}

observer.OnCompleted();
return Task.CompletedTask;
}

void IDisposable.Dispose()
{
}
}

protected class MockObserver<T> : IObserver<T>
{
private readonly TaskCompletionSource<ImmutableList<T>> completed = new TaskCompletionSource<ImmutableList<T>>();

internal event EventHandler<T>? Next;

[System.Runtime.Serialization.IgnoreDataMember]
internal ImmutableList<T> ReceivedValues { get; private set; } = ImmutableList<T>.Empty;

[System.Runtime.Serialization.IgnoreDataMember]
matteo-prosperi marked this conversation as resolved.
Show resolved Hide resolved
internal Task<ImmutableList<T>> Completion => this.completed.Task;

internal AsyncAutoResetEvent ItemReceived { get; } = new AsyncAutoResetEvent();

public void OnCompleted() => this.completed.SetResult(this.ReceivedValues);

public void OnError(Exception error) => this.completed.SetException(error);

public void OnNext(T value)
{
Assert.False(this.completed.Task.IsCompleted);
this.ReceivedValues = this.ReceivedValues.Add(value);
this.Next?.Invoke(this, value);
this.ItemReceived.Set();
}
}
}