Skip to content

Commit

Permalink
Fix Enum field type bug found when underlying type is set from assemb…
Browse files Browse the repository at this point in the history
…ly loaded with MLC (#106375)

* Fix EnumBuilder.UnderlyingSystemType property

* Remove validation that failing when setting constants with core assembly type
  • Loading branch information
buyaa-n authored Aug 15, 2024
1 parent c88c31b commit 05053c4
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -303,4 +303,7 @@
<data name="InvalidOperation_UnmatchingSymScope" xml:space="preserve">
<value>Unmatching symbol scope.</value>
</data>
<data name="Argument_MustBeEnum" xml:space="preserve">
<value>Type provided must be an Enum.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ protected override void SetCustomAttributeCore(ConstructorInfo con, ReadOnlySpan

public override Type? ReflectedType => _typeBuilder.ReflectedType;

public override Type UnderlyingSystemType => GetEnumUnderlyingType();
public override Type UnderlyingSystemType => this;

public override Type GetEnumUnderlyingType() => _underlyingField.FieldType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,78 +40,10 @@ internal FieldBuilderImpl(TypeBuilderImpl typeBuilder, string fieldName, Type ty
protected override void SetConstantCore(object? defaultValue)
{
_typeBuilder.ThrowIfCreated();
ValidateDefaultValueType(defaultValue, _fieldType);
_defaultValue = defaultValue;
_attributes |= FieldAttributes.HasDefault;
}

internal static void ValidateDefaultValueType(object? defaultValue, Type destinationType)
{
if (defaultValue == null)
{
// nullable value types can hold null value.
if (destinationType.IsValueType && !(destinationType.IsGenericType && destinationType.GetGenericTypeDefinition() == typeof(Nullable<>)))
{
throw new ArgumentException(SR.Argument_ConstantNull);
}
}
else
{
Type sourceType = defaultValue.GetType();
// We should allow setting a constant value on a ByRef parameter
if (destinationType.IsByRef)
{
destinationType = destinationType.GetElementType()!;
}

// Convert nullable types to their underlying type.
destinationType = Nullable.GetUnderlyingType(destinationType) ?? destinationType;

if (destinationType.IsEnum)
{
Type underlyingType;
if (destinationType is EnumBuilderImpl enumBldr)
{
underlyingType = enumBldr.GetEnumUnderlyingType();

if (sourceType != enumBldr._typeBuilder.UnderlyingSystemType &&
sourceType != underlyingType &&
// If the source type is an enum, should not throw when the underlying types match
sourceType.IsEnum &&
sourceType.GetEnumUnderlyingType() != underlyingType)
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
else if (destinationType is TypeBuilderImpl typeBldr)
{
underlyingType = typeBldr.UnderlyingSystemType;

if (underlyingType == null || (sourceType != typeBldr.UnderlyingSystemType && sourceType != underlyingType))
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
else
{
underlyingType = Enum.GetUnderlyingType(destinationType);

if (sourceType != destinationType && sourceType != underlyingType)
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
}
else
{
if (!destinationType.IsAssignableFrom(sourceType))
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
}
}

internal void SetData(byte[] data)
{
_rvaData = data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ public override void EmitCall(OpCode opcode, MethodInfo methodInfo, Type[]? opti
}

EmitOpcode(opcode);
UpdateStackSize(GetStackChange(opcode, methodInfo, optionalParameterTypes));
UpdateStackSize(GetStackChange(opcode, methodInfo, _moduleBuilder.GetTypeFromCoreAssembly(CoreTypeId.Void), optionalParameterTypes));
if (optionalParameterTypes == null || optionalParameterTypes.Length == 0)
{
WriteOrReserveToken(_moduleBuilder.TryGetMethodHandle(methodInfo), methodInfo);
Expand All @@ -613,12 +613,12 @@ public override void EmitCall(OpCode opcode, MethodInfo methodInfo, Type[]? opti
}
}

private static int GetStackChange(OpCode opcode, MethodInfo methodInfo, Type[]? optionalParameterTypes)
private static int GetStackChange(OpCode opcode, MethodInfo methodInfo, Type voidType, Type[]? optionalParameterTypes)
{
int stackChange = 0;

// Push the return value if there is one.
if (methodInfo.ReturnType != typeof(void))
if (methodInfo.ReturnType != voidType)
{
stackChange++;
}
Expand Down Expand Up @@ -665,7 +665,7 @@ public override void EmitCalli(OpCode opcode, CallingConventions callingConventi
}
}

int stackChange = GetStackChange(returnType, parameterTypes);
int stackChange = GetStackChange(returnType, _moduleBuilder.GetTypeFromCoreAssembly(CoreTypeId.Void), parameterTypes);

// Pop off VarArg arguments.
if (optionalParameterTypes != null)
Expand All @@ -685,17 +685,17 @@ public override void EmitCalli(OpCode opcode, CallingConventions callingConventi

public override void EmitCalli(OpCode opcode, CallingConvention unmanagedCallConv, Type? returnType, Type[]? parameterTypes)
{
int stackChange = GetStackChange(returnType, parameterTypes);
int stackChange = GetStackChange(returnType, _moduleBuilder.GetTypeFromCoreAssembly(CoreTypeId.Void), parameterTypes);
UpdateStackSize(stackChange);
Emit(OpCodes.Calli);
_il.Token(_moduleBuilder.GetSignatureToken(unmanagedCallConv, returnType, parameterTypes));
}

private static int GetStackChange(Type? returnType, Type[]? parameterTypes)
private static int GetStackChange(Type? returnType, Type voidType, Type[]? parameterTypes)
{
int stackChange = 0;
// If there is a non-void return type, push one.
if (returnType != typeof(void))
if (returnType != voidType)
{
stackChange++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ public ParameterBuilderImpl(MethodBuilderImpl methodBuilder, int sequence, Param

public override void SetConstant(object? defaultValue)
{
Type parameterType = _position == 0 ? _methodBuilder.ReturnType : _methodBuilder.ParameterTypes![_position - 1];
FieldBuilderImpl.ValidateDefaultValueType(defaultValue, parameterType);
_defaultValue = defaultValue;
_attributes |= ParameterAttributes.HasDefault;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ protected override void AddOtherMethodCore(MethodBuilder mdBuilder)
protected override void SetConstantCore(object? defaultValue)
{
_containingType.ThrowIfCreated();
FieldBuilderImpl.ValidateDefaultValueType(defaultValue, _propertyType);
_defaultValue = defaultValue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ protected override MethodBuilder DefineMethodCore(string name, MethodAttributes
{
ThrowIfCreated();


MethodBuilderImpl methodBuilder = new(name, attributes, callingConvention, returnType, returnTypeRequiredCustomModifiers,
returnTypeOptionalCustomModifiers, parameterTypes, parameterTypeRequiredCustomModifiers, parameterTypeOptionalCustomModifiers, _module, this);
_methodDefinitions.Add(methodBuilder);
Expand Down Expand Up @@ -616,23 +615,22 @@ public override Type GetGenericTypeDefinition()
public override string? Namespace => _namespace;
public override Assembly Assembly => _module.Assembly;
public override Module Module => _module;
public override Type UnderlyingSystemType
public override Type UnderlyingSystemType => this;

public override Type GetEnumUnderlyingType()
{
get
if (IsEnum)
{
if (IsEnum)
{
if (_enumUnderlyingType == null)
{
throw new InvalidOperationException(SR.InvalidOperation_NoUnderlyingTypeOnEnum);
}

return _enumUnderlyingType;
}
else
if (_enumUnderlyingType == null)
{
return this;
throw new InvalidOperationException(SR.InvalidOperation_NoUnderlyingTypeOnEnum);
}

return _enumUnderlyingType;
}
else
{
throw new ArgumentException(SR.Argument_MustBeEnum);
}
}
public override bool IsSZArray => false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public static IEnumerable<object[]> DefineLiteral_TestData()
yield return new object[] { typeof(uint), (uint)1 };

yield return new object[] { typeof(int), 0 };
yield return new object[] { typeof(int), 1 };
yield return new object[] { typeof(int), Test.Second };

yield return new object[] { typeof(ulong), (ulong)0 };
yield return new object[] { typeof(ulong), (ulong)1 };
Expand Down Expand Up @@ -100,7 +100,7 @@ public void CreateEnumWithMlc()
PersistedAssemblyBuilder ab = new PersistedAssemblyBuilder(PopulateAssemblyName(), mlc.CoreAssembly);
ModuleBuilder mb = ab.DefineDynamicModule("My Module");
Type intType = mlc.CoreAssembly.GetType("System.Int32");
EnumBuilder enumBuilder = mb.DefineEnum("TestEnum", TypeAttributes.Public, typeof(int));
EnumBuilder enumBuilder = mb.DefineEnum("TestEnum", TypeAttributes.Public, intType);
FieldBuilder field = enumBuilder.DefineLiteral("Default", 0);

enumBuilder.CreateTypeInfo();
Expand All @@ -118,7 +118,7 @@ public void CreateEnumWithMlc()

FieldInfo testField = createdEnum.GetField("Default");
Assert.Equal(createdEnum, testField.FieldType);
Assert.Equal(typeof(int), enumBuilder.GetEnumUnderlyingType());
Assert.Equal(intType, enumBuilder.GetEnumUnderlyingType());
Assert.Equal(FieldAttributes.Public | FieldAttributes.Static | FieldAttributes.Literal | FieldAttributes.HasDefault, testField.Attributes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ public void SetConstantVariousValues(Type returnType, object defaultValue)
Assert.Equal(defaultValue, property.GetConstantValue());
}

[Theory]
[MemberData(nameof(SetConstant_TestData))]
public void SetConstantVariousValuesMlcCoreAssembly(Type returnType, object defaultValue)
{
using (MetadataLoadContext mlc = new MetadataLoadContext(new CoreMetadataAssemblyResolver()))
{
PersistedAssemblyBuilder ab = new PersistedAssemblyBuilder(new AssemblyName("MyDynamicAssembly"), mlc.CoreAssembly);
ModuleBuilder mb = ab.DefineDynamicModule("My Module");
Type returnTypeFromCore = returnType != typeof(PropertyBuilderTest11.Colors) ? mlc.CoreAssembly.GetType(returnType.FullName, true) : returnType;
TypeBuilder type = mb.DefineType("MyType", TypeAttributes.Public);

PropertyBuilder property = type.DefineProperty("TestProperty", PropertyAttributes.HasDefault, returnTypeFromCore, null);
property.SetConstant(defaultValue);

Assert.Equal(defaultValue, property.GetConstantValue());
}
}

[Fact]
public void SetCustomAttribute_ConstructorInfo_ByteArray_NullConstructorInfo_ThrowsArgumentNullException()
{
Expand Down Expand Up @@ -194,7 +212,6 @@ public void Set_WhenTypeAlreadyCreated_ThrowsInvalidOperationException()
MethodAttributes getMethodAttributes = MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.HideBySig;
MethodBuilder method = type.DefineMethod("TestMethod", getMethodAttributes, typeof(int), null);
method.GetILGenerator().Emit(OpCodes.Ret);
AssertExtensions.Throws<ArgumentException>(() => property.SetConstant((decimal)10));
CustomAttributeBuilder customAttrBuilder = new CustomAttributeBuilder(typeof(IntPropertyAttribute).GetConstructor([typeof(int)]), [10]);
type.CreateType();

Expand All @@ -204,18 +221,5 @@ public void Set_WhenTypeAlreadyCreated_ThrowsInvalidOperationException()
Assert.Throws<InvalidOperationException>(() => property.SetConstant(1));
Assert.Throws<InvalidOperationException>(() => property.SetCustomAttribute(customAttrBuilder));
}

[Fact]
public void SetConstant_ValidationThrows()
{
AssemblySaveTools.PopulateAssemblyBuilderAndTypeBuilder(out TypeBuilder type);
FieldBuilder field = type.DefineField("TestField", typeof(int), FieldAttributes.Private);
PropertyBuilder property = type.DefineProperty("TestProperty", PropertyAttributes.HasDefault, typeof(int), null);

AssertExtensions.Throws<ArgumentException>(() => property.SetConstant((decimal)10));
AssertExtensions.Throws<ArgumentException>(() => property.SetConstant(null));
type.CreateType();
Assert.Throws<InvalidOperationException>(() => property.SetConstant(1));
}
}
}

0 comments on commit 05053c4

Please sign in to comment.