Skip to content

Commit

Permalink
[release/8.0] Improve usage of Type.GetType when activating types i…
Browse files Browse the repository at this point in the history
…n data protection (#54762)

* Added test demonstrating behavior on TypeLoadException

* Catch exceptions when calling GetType and fall back to IActivator.CreateInstance<IXmlDecryptor>

* Move trimming suppression

* Include comment about how throwOnError: false can still throw

* Check type name for known value before getting type

* Remove unnecessary suppression

* Comment

* PR feedback

* Update src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs

---------

Co-authored-by: Kristian Hellang <kristian@identitystream.com>
Co-authored-by: James Newton-King <james@newtonking.com>
  • Loading branch information
3 people authored Apr 2, 2024
1 parent 53e662d commit 22bb1c3
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;

namespace Microsoft.AspNetCore.DataProtection.Internal;

internal sealed class DefaultTypeNameResolver : ITypeNameResolver
{
public static readonly DefaultTypeNameResolver Instance = new();

private DefaultTypeNameResolver()
{
}

[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType is only used to resolve statically known types that are referenced by DataProtection assembly.")]
public bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type)
{
try
{
// Some exceptions are thrown regardless of the value of throwOnError.
// For example, if the type is found but cannot be loaded,
// a System.TypeLoadException is thrown even if throwOnError is false.
type = Type.GetType(typeName, throwOnError: false);
return type != null;
}
catch
{
type = null;
return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;

namespace Microsoft.AspNetCore.DataProtection.Internal;

internal interface ITypeNameResolver
{
bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type);
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public sealed class XmlKeyManager : IKeyManager, IInternalXmlKeyManager
private const string RevokeAllKeysValue = "*";

private readonly IActivator _activator;
private readonly ITypeNameResolver _typeNameResolver;
private readonly AlgorithmConfiguration _authenticatedEncryptorConfiguration;
private readonly IKeyEscrowSink? _keyEscrowSink;
private readonly IInternalXmlKeyManager _internalKeyManager;
Expand Down Expand Up @@ -112,6 +113,8 @@ internal XmlKeyManager(
var escrowSinks = keyManagementOptions.Value.KeyEscrowSinks;
_keyEscrowSink = escrowSinks.Count > 0 ? new AggregateKeyEscrowSink(escrowSinks) : null;
_activator = activator;
// Note: ITypeNameResolver is only implemented on the activator in tests. In production, it's always DefaultTypeNameResolver.
_typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;
TriggerAndResetCacheExpirationToken(suppressLogging: true);
_internalKeyManager = _internalKeyManager ?? this;
_encryptorFactories = keyManagementOptions.Value.AuthenticatedEncryptorFactories;
Expand Down Expand Up @@ -460,27 +463,27 @@ IAuthenticatedEncryptorDescriptor IInternalXmlKeyManager.DeserializeDescriptorFr
}
}

[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
private IAuthenticatedEncryptorDescriptorDeserializer CreateDeserializer(string descriptorDeserializerTypeName)
{
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName)
// typeNameToMatch will be used for matching against known types but not passed to the activator.
// The activator will do its own forwarding.
var typeNameToMatch = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName)
? forwardedTypeName
: descriptorDeserializerTypeName;
var type = Type.GetType(resolvedTypeName, throwOnError: false);

if (type == typeof(AuthenticatedEncryptorDescriptorDeserializer))
if (typeof(AuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<AuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
else if (type == typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<CngCbcAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
else if (type == typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<CngGcmAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
else if (type == typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer))
else if (typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver))
{
return _activator.CreateInstance<ManagedAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
}
Expand Down
13 changes: 13 additions & 0 deletions src/DataProtection/DataProtection/src/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.DataProtection.Internal;

namespace Microsoft.AspNetCore.DataProtection;

Expand Down Expand Up @@ -39,4 +40,16 @@ public static Type GetTypeWithTrimFriendlyErrorMessage(string typeName)
throw new InvalidOperationException($"Unable to load type '{typeName}'. If the app is published with trimming then this type may have been trimmed. Ensure the type's assembly is excluded from trimming.", ex);
}
}

public static bool MatchName(this Type matchType, string resolvedTypeName, ITypeNameResolver typeNameResolver)
{
// Before attempting to resolve the name to a type, check if it starts with the full name of the type.
// Use StartsWith to ignore potential assembly version differences.
if (matchType.FullName != null && resolvedTypeName.StartsWith(matchType.FullName, StringComparison.Ordinal))
{
return typeNameResolver.TryResolveType(resolvedTypeName, out var resolvedType) && resolvedType == matchType;
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,30 @@ public static XElement DecryptElement(this XElement element, IActivator activato
return doc.Root!;
}

[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
private static IXmlDecryptor CreateDecryptor(IActivator activator, string decryptorTypeName)
{
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
// typeNameToMatch will be used for matching against known types but not passed to the activator.
// The activator will do its own forwarding.
var typeNameToMatch = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
? forwardedTypeName
: decryptorTypeName;
var type = Type.GetType(resolvedTypeName, throwOnError: false);

if (type == typeof(DpapiNGXmlDecryptor))
// Note: ITypeNameResolver is only implemented on the activator in tests. In production, it's always DefaultTypeNameResolver.
var typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;

if (typeof(DpapiNGXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<DpapiNGXmlDecryptor>(decryptorTypeName);
}
else if (type == typeof(DpapiXmlDecryptor))
else if (typeof(DpapiXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<DpapiXmlDecryptor>(decryptorTypeName);
}
else if (type == typeof(EncryptedXmlDecryptor))
else if (typeof(EncryptedXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<EncryptedXmlDecryptor>(decryptorTypeName);
}
else if (type == typeof(NullXmlDecryptor))
else if (typeof(NullXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver))
{
return activator.CreateInstance<NullXmlDecryptor>(decryptorTypeName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,100 @@ public void DecryptElement_RootNodeRequiresDecryption_Success()
XmlAssert.Equal("<newNode />", retVal);
}

[Fact]
public void DecryptElement_CustomType_TypeNameResolverNotCalled()
{
// Arrange
var decryptorTypeName = typeof(MyXmlDecryptor).AssemblyQualifiedName;

var original = XElement.Parse(@$"
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
<node />
</x:encryptedSecret>");

var mockActivator = new Mock<IActivator>();
mockActivator.ReturnDecryptedElementGivenDecryptorTypeNameAndInput(decryptorTypeName, "<node />", "<newNode />");
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();

var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
var services = serviceCollection.BuildServiceProvider();
var activator = services.GetActivator();

// Act
var retVal = original.DecryptElement(activator);

// Assert
XmlAssert.Equal("<newNode />", retVal);
Type resolvedType;
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Never());
}

[Fact]
public void DecryptElement_KnownType_TypeNameResolverCalled()
{
// Arrange
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;
TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName);

var original = XElement.Parse(@$"
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
<node>
<value />
</node>
</x:encryptedSecret>");

var mockActivator = new Mock<IActivator>();
mockActivator.Setup(o => o.CreateInstance(typeof(NullXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
var resolvedType = typeof(NullXmlDecryptor);
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(forwardedTypeName, out resolvedType)).Returns(true);

var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
var services = serviceCollection.BuildServiceProvider();
var activator = services.GetActivator();

// Act
var retVal = original.DecryptElement(activator);

// Assert
XmlAssert.Equal("<value />", retVal);
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
}

[Fact]
public void DecryptElement_KnownType_UnableToResolveType_Success()
{
// Arrange
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;

var original = XElement.Parse(@$"
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
<node>
<value />
</node>
</x:encryptedSecret>");

var mockActivator = new Mock<IActivator>();
mockActivator.Setup(o => o.CreateInstance(typeof(IXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
Type resolvedType = null;
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(It.IsAny<string>(), out resolvedType)).Returns(false);

var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
var services = serviceCollection.BuildServiceProvider();
var activator = services.GetActivator();

// Act
var retVal = original.DecryptElement(activator);

// Assert
XmlAssert.Equal("<value />", retVal);
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
}

[Fact]
public void DecryptElement_MultipleNodesRequireDecryption_AvoidsRecursion_Success()
{
Expand Down

0 comments on commit 22bb1c3

Please sign in to comment.