Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added an option to disable dynamic assembly loading #6280

Merged
merged 3 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
namespace NServiceBus.AcceptanceTests.Serialization
{
using System;
using System.Linq;
using System.Threading.Tasks;
using AcceptanceTesting;
using EndpointTemplates;
using NServiceBus.Pipeline;
using NUnit.Framework;

public class When_dynamic_loading_is_disabled : NServiceBusAcceptanceTest
{
[Test]
public async Task Should_not_load_type_dynamically()
{
var context = await Scenario.Define<Context>()
.WithEndpoint<ReceivingEndpoint>(e => e
.DoNotFailOnErrorMessages()
.When(session => session.SendLocal(new Message()))
)
.Done(c => c.MessageReceived)
.Run();

Assert.AreEqual(1, context.FailedMessages.Single().Value.Count);
Exception exception = context.FailedMessages.Single().Value.Single().Exception;
Assert.IsInstanceOf<MessageDeserializationException>(exception);
Assert.AreEqual($"Could not determine the message type from the '{Headers.EnclosedMessageTypes}' header and message type inference from the message body has been disabled. Ensure the header is set or enable message type inference.", exception.InnerException.Message);
}

class Context : ScenarioContext
{
public bool MessageReceived { get; set; }
}

class ReceivingEndpoint : EndpointConfigurationBuilder
{
public ReceivingEndpoint()
{
EndpointSetup<DefaultServer>(cfg =>
{
cfg.Pipeline.Register(typeof(PatchEnclosedMessageTypeHeader), "Patches the EnclosedMessageTypeHeader to contain a type that requires Type.GetType to be invoked.");
var serializerSettings = cfg.UseSerialization<XmlSerializer>();
serializerSettings.DisableDynamicTypeLoading();
serializerSettings.DisableMessageTypeInference(); // just throw when we can't find the message type
}).ExcludeType<PatchMessage>();
}

class PatchEnclosedMessageTypeHeader : Behavior<IIncomingPhysicalMessageContext>
{
Context testContext;

public PatchEnclosedMessageTypeHeader(Context testContext) => this.testContext = testContext;

public override Task Invoke(IIncomingPhysicalMessageContext context, Func<Task> next)
{
testContext.MessageReceived = true;

context.Message.Headers[Headers.EnclosedMessageTypes] = typeof(PatchMessage).AssemblyQualifiedName;

return next();
}
}
}

public class Message : IMessage
{
}

public class PatchMessage : IMessage
{
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,8 @@ namespace NServiceBus
}
public class static SerializationExtensionsExtensions
{
public static void DisableDynamicTypeLoading<T>(this NServiceBus.Serialization.SerializationExtensions<T> config)
where T : NServiceBus.Serialization.SerializationDefinition { }
public static void DisableMessageTypeInference<T>(this NServiceBus.Serialization.SerializationExtensions<T> config)
where T : NServiceBus.Serialization.SerializationDefinition { }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,8 @@ namespace NServiceBus
}
public class static SerializationExtensionsExtensions
{
public static void DisableDynamicTypeLoading<T>(this NServiceBus.Serialization.SerializationExtensions<T> config)
where T : NServiceBus.Serialization.SerializationDefinition { }
public static void DisableMessageTypeInference<T>(this NServiceBus.Serialization.SerializationExtensions<T> config)
where T : NServiceBus.Serialization.SerializationDefinition { }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class SerializeMessageConnectorTests
[Test]
public async Task Should_set_content_type_header()
{
var registry = new MessageMetadataRegistry(new Conventions().IsMessageType);
var registry = new MessageMetadataRegistry(new Conventions().IsMessageType, true);

registry.RegisterMessageTypesFoundIn(new List<Type>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static IOutgoingSendContext CreateContext(SendOptions options = null, object mes

static SendConnector InitializeBehavior(FakeRouter router = null)
{
var metadataRegistry = new MessageMetadataRegistry(new Conventions().IsMessageType);
var metadataRegistry = new MessageMetadataRegistry(new Conventions().IsMessageType, true);
metadataRegistry.RegisterMessageTypesFoundIn(new List<Type>
{
typeof(MyMessage),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void LoggerSetup()
[SetUp]
public void Setup()
{
metadataRegistry = new MessageMetadataRegistry(_ => true);
metadataRegistry = new MessageMetadataRegistry(_ => true, true);
endpointInstances = new EndpointInstances();
subscriptionStorage = new FakeSubscriptionStorage();
router = new UnicastPublishRouter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,25 @@ public class When_getting_message_definition
[Test]
public void Should_throw_an_exception_for_a_unmapped_type()
{
var defaultMessageRegistry = new MessageMetadataRegistry(_ => false);
var defaultMessageRegistry = new MessageMetadataRegistry(_ => false, true);
Assert.Throws<Exception>(() => defaultMessageRegistry.GetMessageMetadata(typeof(int)));
}

[Test]
public void Should_return_null_when_resolving_unknown_type_from_type_identifier()
{
var registry = new MessageMetadataRegistry(t => true, true);

var metadata = registry.GetMessageMetadata(
"SomeNamespace.SomeType, SomeAssemblyName, Version=81.0.0.0, Culture=neutral, PublicKeyToken=null");

Assert.IsNull(metadata);
}

[Test]
public void Should_return_metadata_for_a_mapped_type()
{
var defaultMessageRegistry = new MessageMetadataRegistry(type => type == typeof(int));
var defaultMessageRegistry = new MessageMetadataRegistry(type => type == typeof(int), true);
defaultMessageRegistry.RegisterMessageTypesFoundIn(new List<Type> { typeof(int) });

var messageMetadata = defaultMessageRegistry.GetMessageMetadata(typeof(int));
Expand All @@ -35,7 +46,7 @@ public void Should_return_metadata_for_a_mapped_type()
[Test]
public void Should_return_the_correct_parent_hierarchy()
{
var defaultMessageRegistry = new MessageMetadataRegistry(new Conventions().IsMessageType);
var defaultMessageRegistry = new MessageMetadataRegistry(new Conventions().IsMessageType, true);

defaultMessageRegistry.RegisterMessageTypesFoundIn(new List<Type> { typeof(MyEvent) });
var messageMetadata = defaultMessageRegistry.GetMessageMetadata(typeof(MyEvent));
Expand All @@ -55,14 +66,47 @@ public void Should_return_the_correct_parent_hierarchy()
[TestCase("NServiceBus.Unicast.Tests.DefaultMessageRegistryTests+When_getting_message_definition+MyEvent")]
public void Should_match_types_from_a_different_assembly(string typeName)
{
var defaultMessageRegistry = new MessageMetadataRegistry(new Conventions().IsMessageType);
var defaultMessageRegistry = new MessageMetadataRegistry(new Conventions().IsMessageType, true);
defaultMessageRegistry.RegisterMessageTypesFoundIn(new List<Type> { typeof(MyEvent) });

var messageMetadata = defaultMessageRegistry.GetMessageMetadata(typeName);

Assert.AreEqual(typeof(MyEvent), messageMetadata.MessageHierarchy.ToList()[0]);
}

[Test]
public void Should_not_match_same_type_names_with_different_namespace()
{
var defaultMessageRegistry = new MessageMetadataRegistry(new Conventions().IsMessageType, true);
defaultMessageRegistry.RegisterMessageTypesFoundIn(new List<Type> { typeof(MyEvent) });

string typeIdentifier = typeof(MyEvent).AssemblyQualifiedName.Replace(typeof(MyEvent).FullName,
$"SomeNamespace.{nameof(MyEvent)}");
var messageMetadata = defaultMessageRegistry.GetMessageMetadata(typeIdentifier);

Assert.IsNull(messageMetadata);
}

[Test]
public void Should_resolve_uninitialized_types_from_loaded_assemblies()
{
var registry = new MessageMetadataRegistry(t => true, true);

var metadata = registry.GetMessageMetadata(typeof(EndpointConfiguration).AssemblyQualifiedName);

Assert.AreEqual(typeof(EndpointConfiguration), metadata.MessageType);
}

[Test]
public void Should_not_resolve_uninitialized_types_from_assembly_when_prohibiting_dynamic_typeloading()
{
var registry = new MessageMetadataRegistry(t => true, false);

var metadata = registry.GetMessageMetadata(typeof(EndpointConfiguration).AssemblyQualifiedName);

Assert.IsNull(metadata);
}

class MyEvent : ConcreteParent1, IInterfaceParent1
{

Expand Down
3 changes: 2 additions & 1 deletion src/NServiceBus.Core/EndpointCreator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ void Initialize()
);
}


void ConfigureMessageTypes()
{
var messageMetadataRegistry = new MessageMetadataRegistry(conventions.IsMessageType);
var messageMetadataRegistry = new MessageMetadataRegistry(conventions.IsMessageType, settings.IsDynamicTypeLoadingEnabled());

messageMetadataRegistry.RegisterMessageTypesFoundIn(settings.GetAvailableTypes());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,22 @@ public static void DisableMessageTypeInference<T>(this SerializationExtensions<T
config.EndpointConfigurationSettings.Set(DisableMessageTypeInferenceKey, true);
}

/// <summary>
/// Disables dynamic type loading via <see cref="System.Type.GetType(string)"/> to prevent loading of assemblies for types passed in message header `NServiceBus.EnclosedMessageTypes` to only allow message types during deserialization that were explicitly loaded.
/// </summary>
public static void DisableDynamicTypeLoading<T>(this SerializationExtensions<T> config) where T : SerializationDefinition
{
Guard.AgainstNull(nameof(config), config);
config.EndpointConfigurationSettings.Set(DisableDynamicTypeLoadingKey, true);
}

internal static bool IsDynamicTypeLoadingEnabled(this ReadOnlySettings endpointConfigurationSettings) =>
!endpointConfigurationSettings.GetOrDefault<bool>(DisableDynamicTypeLoadingKey);

internal static bool IsMessageTypeInferenceEnabled(this ReadOnlySettings endpointConfigurationSettings) =>
!endpointConfigurationSettings.GetOrDefault<bool>(DisableMessageTypeInferenceKey);

const string DisableMessageTypeInferenceKey = "NServiceBus.Serialization.DisableMessageTypeInference";
const string DisableDynamicTypeLoadingKey = "NServiceBus.Serialization.DisableDynamicTypeLoading";
}
}
4 changes: 3 additions & 1 deletion src/NServiceBus.Core/Serialization/SerializationFeature.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ protected internal sealed override void Setup(FeatureConfigurationContext contex
}

var allowMessageTypeInference = settings.IsMessageTypeInferenceEnabled();
var allowDynamicTypeLoading = settings.IsDynamicTypeLoadingEnabled();
var resolver = new MessageDeserializerResolver(defaultSerializer, additionalDeserializers);
var logicalMessageFactory = new LogicalMessageFactory(messageMetadataRegistry, mapper);
context.Pipeline.Register("DeserializeLogicalMessagesConnector", new DeserializeMessageConnector(resolver, logicalMessageFactory, messageMetadataRegistry, mapper, allowMessageTypeInference), "Deserializes the physical message body into logical messages");
Expand All @@ -69,7 +70,8 @@ protected internal sealed override void Setup(FeatureConfigurationContext contex
defaultSerializer.ContentType
},
AdditionalDeserializers = additionalDeserializerDiagnostics,
AllowMessageTypeInference = allowMessageTypeInference
AllowMessageTypeInference = allowMessageTypeInference,
AllowDynamicTypeLoading = allowDynamicTypeLoading
});
}

Expand Down
48 changes: 30 additions & 18 deletions src/NServiceBus.Core/Unicast/Messages/MessageMetadataRegistry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
using System.Collections.Generic;
using System.Linq;
using Logging;

using NServiceBus;
/// <summary>
/// Cache of message metadata.
/// </summary>
public class MessageMetadataRegistry
{
internal MessageMetadataRegistry(Func<Type, bool> isMessageType)
internal MessageMetadataRegistry(Func<Type, bool> isMessageType, bool allowDynamicTypeLoading)
{
this.isMessageType = isMessageType;
this.allowDynamicTypeLoading = allowDynamicTypeLoading;
}

/// <summary>
Expand Down Expand Up @@ -48,26 +49,35 @@ public MessageMetadata GetMessageMetadata(string messageTypeIdentifier)
{
Guard.AgainstNullAndEmpty(nameof(messageTypeIdentifier), messageTypeIdentifier);

var messageType = GetType(messageTypeIdentifier);
var cacheHit = cachedTypes.TryGetValue(messageTypeIdentifier, out var messageType);

if (messageType == null)
if (!cacheHit)
{
Logger.DebugFormat("Message type: '{0}' could not be determined by a 'Type.GetType', scanning known messages for a match", messageTypeIdentifier);
messageType = GetType(messageTypeIdentifier);

foreach (var item in messages.Values)
if (messageType == null)
{
var messageTypeFullName = GetMessageTypeNameWithoutAssembly(messageTypeIdentifier);

if (item.MessageType.FullName == messageTypeIdentifier ||
item.MessageType.FullName == messageTypeFullName)
foreach (var item in messages.Values)
{
Logger.DebugFormat("Message type: '{0}' was mapped to '{1}'", messageTypeIdentifier, item.MessageType.AssemblyQualifiedName);
var messageTypeFullName = GetMessageTypeNameWithoutAssembly(messageTypeIdentifier);

cachedTypes[messageTypeIdentifier] = item.MessageType;
return item;
if (item.MessageType.FullName == messageTypeIdentifier ||
item.MessageType.FullName == messageTypeFullName)
{
Logger.DebugFormat("Message type: '{0}' was mapped to '{1}'", messageTypeIdentifier, item.MessageType.AssemblyQualifiedName);

cachedTypes[messageTypeIdentifier] = item.MessageType;
return item;
}
}
Logger.DebugFormat("Message type: '{0}' No match on known messages", messageTypeIdentifier);
}

cachedTypes[messageTypeIdentifier] = messageType;
}

if (messageType == null)
{
return null;
}

Expand Down Expand Up @@ -98,13 +108,12 @@ string GetMessageTypeNameWithoutAssembly(string messageTypeIdentifier)

Type GetType(string messageTypeIdentifier)
{
if (!cachedTypes.TryGetValue(messageTypeIdentifier, out var type))
if (allowDynamicTypeLoading)
{
type = Type.GetType(messageTypeIdentifier, false);
cachedTypes[messageTypeIdentifier] = type;
return Type.GetType(messageTypeIdentifier, false);
timbussmann marked this conversation as resolved.
Show resolved Hide resolved
}

return type;
Logger.Warn($"Unknown message type identifier '{messageTypeIdentifier}'. Dynamic type loading is disabled. Make sure the type is loaded before starting the endpoint or enable dynamic type loading.");
return null;
}

internal IEnumerable<MessageMetadata> GetAllMessages()
Expand Down Expand Up @@ -158,6 +167,7 @@ MessageMetadata RegisterMessageType(Type messageType)
}.Concat(parentMessages).ToArray());

messages[messageType.TypeHandle] = metadata;
cachedTypes.TryAdd(messageType.AssemblyQualifiedName, messageType);

return metadata;
}
Expand All @@ -168,6 +178,7 @@ static int PlaceInMessageHierarchy(Type type)
{
return type.GetInterfaces().Length;
}

var result = 0;

while (type.BaseType != null)
Expand Down Expand Up @@ -198,6 +209,7 @@ static IEnumerable<Type> GetParentTypes(Type type)
}

Func<Type, bool> isMessageType;
readonly bool allowDynamicTypeLoading;
ConcurrentDictionary<RuntimeTypeHandle, MessageMetadata> messages = new ConcurrentDictionary<RuntimeTypeHandle, MessageMetadata>();
ConcurrentDictionary<string, Type> cachedTypes = new ConcurrentDictionary<string, Type>();

Expand Down