Skip to content

Commit

Permalink
try to handle the type and assembly name mangling
Browse files Browse the repository at this point in the history
  • Loading branch information
adamsitnik committed May 20, 2024
1 parent 00608e4 commit 6354a91
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,29 @@ namespace System.Runtime.Serialization.BinaryFormat;
/// </see>
/// </para>
/// </remarks>
[DebuggerDisplay("{Name}")]
[DebuggerDisplay("{RawName}")]
internal sealed class ClassInfo
{
private ClassInfo(int objectId, TypeName name, Dictionary<string, int> memberNames)
private ClassInfo(int objectId, string name, Dictionary<string, int> memberNames, PayloadOptions payloadOptions)
{
ObjectId = objectId;
Name = name;
RawName = name;
MemberNames = memberNames;
PayloadOptions = payloadOptions;
}

internal int ObjectId { get; }

internal TypeName Name { get; }
internal string RawName { get; }

internal Dictionary<string, int> MemberNames { get; }

internal PayloadOptions PayloadOptions { get; }

internal static ClassInfo Parse(BinaryReader reader, PayloadOptions payloadOptions)
{
int objectId = reader.ReadInt32();
TypeName typeName = reader.ReadTypeName(payloadOptions);
string typeName = reader.ReadString();
int memberCount = reader.ReadInt32();

// The attackers could create an input with MANY member names.
Expand All @@ -55,6 +58,37 @@ internal static ClassInfo Parse(BinaryReader reader, PayloadOptions payloadOptio
memberNames.Add(reader.ReadString(), i);
}

return new(objectId, typeName, memberNames);
return new(objectId, typeName, memberNames, payloadOptions);
}

internal TypeName GetTypeNameEvenIfMangled(string libraryName)
{
if (TypeName.TryParse(RawName.AsSpan(), out TypeName? typeName, PayloadOptions.TypeNameParseOptions))
{
if (typeName.AssemblyName is not null)
{
throw new SerializationException("Type names must not contain assembly names");
}
}
else if (!PayloadOptions.SupportMangledNames)
{
throw new SerializationException($"Invalid type name: '{RawName}'");
}

// adsitnik: use array pool to avoid allocations (if it turns out to be the right direction)
string assemblyQualifiedName = $"{RawName}, {libraryName}";

if (!TypeName.TryParse(assemblyQualifiedName.AsSpan(), out typeName, PayloadOptions.TypeNameParseOptions))
{
throw new SerializationException($"Invalid type name: '{RawName}' or library name: '{libraryName}'");
}

if (typeName.AssemblyName is null)
{
typeName = typeName.WithAssemblyName(FormatterServices.CoreLibRawName);
}

Debug.Assert(typeName.AssemblyName is not null);
return typeName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ internal bool IsElementType(Type typeElement, RecordMap recordMap)

BinaryLibraryRecord libraryRecord = (BinaryLibraryRecord)recordMap[typeInfo.LibraryId];
string assemblyName = FormatterServices.GetAssemblyNameIncludingTypeForwards(typeElement);
return assemblyName == libraryRecord.LibraryName.FullName;
return assemblyName == libraryRecord.LibraryName;
default:
throw new NotSupportedException();
}
Expand Down Expand Up @@ -252,8 +252,8 @@ internal TypeName GetElementTypeName(RecordMap recordMap)
return ((TypeName)additionalInfo!).WithAssemblyName(FormatterServices.CoreLibAssemblyName.FullName);
case BinaryType.Class:
ClassTypeInfo typeInfo = (ClassTypeInfo)additionalInfo!;
AssemblyNameInfo libraryName = ((BinaryLibraryRecord)recordMap[typeInfo.LibraryId]).LibraryName;
return typeInfo.TypeName.WithAssemblyName(libraryName.FullName);
string libraryName = ((BinaryLibraryRecord)recordMap[typeInfo.LibraryId]).LibraryName;
return typeInfo.TypeName.WithAssemblyName(libraryName);
default:
throw new NotSupportedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ sealed class PayloadOptions
public PayloadOptions() { }

public TypeNameParseOptions? TypeNameParseOptions { get; set; }

public bool SupportMangledNames { get; set; } = true; // adsitnik set to false and propagate
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ namespace System.Runtime.Serialization.BinaryFormat;
/// </remarks>
internal sealed class BinaryLibraryRecord : SerializationRecord
{
private BinaryLibraryRecord(int libraryId, AssemblyNameInfo libraryName)
private BinaryLibraryRecord(int libraryId, string libraryName)
{
ObjectId = libraryId;
LibraryName = libraryName;
}

public override RecordType RecordType => RecordType.BinaryLibrary;

internal AssemblyNameInfo LibraryName { get; }
internal string LibraryName { get; }

public override int ObjectId { get; }

internal static BinaryLibraryRecord Parse(BinaryReader reader)
=> new(reader.ReadInt32(), reader.ReadLibraryName());
=> new(reader.ReadInt32(), reader.ReadString());
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ private protected ClassRecord(ClassInfo classInfo)
MemberValues = [];
}

public TypeName TypeName => _typeName ??= ClassInfo.Name.WithAssemblyName(LibraryName.FullName);
public TypeName TypeName => _typeName ??= ClassInfo.GetTypeNameEvenIfMangled(LibraryName);

internal abstract AssemblyNameInfo LibraryName { get; }
internal abstract string LibraryName { get; }

// Currently we don't expose raw values, so we are not preserving the order here.
public IEnumerable<string> MemberNames => ClassInfo.MemberNames.Keys;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ private ClassWithIdRecord(int objectId, ClassRecord metadataClass) : base(metada

public override RecordType RecordType => RecordType.ClassWithId;

internal override AssemblyNameInfo LibraryName => MetadataClass.LibraryName;
internal override string LibraryName => MetadataClass.LibraryName;

public override int ObjectId { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ private ClassWithMembersAndTypesRecord(ClassInfo classInfo, BinaryLibraryRecord

public override RecordType RecordType => RecordType.ClassWithMembersAndTypes;

internal override AssemblyNameInfo LibraryName => Library.LibraryName;
internal override string LibraryName => Library.LibraryName;

internal BinaryLibraryRecord Library { get; }

Expand All @@ -36,8 +36,8 @@ private ClassWithMembersAndTypesRecord(ClassInfo classInfo, BinaryLibraryRecord
internal override int ExpectedValuesCount => MemberTypeInfo.Infos.Count;

public override bool IsTypeNameMatching(Type type)
=> FormatterServices.GetTypeFullNameIncludingTypeForwards(type) == ClassInfo.Name.FullName
&& FormatterServices.GetAssemblyNameIncludingTypeForwards(type) == Library.LibraryName.FullName;
=> FormatterServices.GetTypeFullNameIncludingTypeForwards(type) == TypeName.FullName
&& FormatterServices.GetAssemblyNameIncludingTypeForwards(type) == Library.LibraryName;

internal static ClassWithMembersAndTypesRecord Parse(BinaryReader reader, RecordMap recordMap, PayloadOptions options)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ private SystemClassWithMembersAndTypesRecord(ClassInfo classInfo, MemberTypeInfo

public override RecordType RecordType => RecordType.SystemClassWithMembersAndTypes;

internal override AssemblyNameInfo LibraryName => FormatterServices.CoreLibAssemblyName;
internal override string LibraryName => FormatterServices.CoreLibRawName;

internal MemberTypeInfo MemberTypeInfo { get; }

internal override int ExpectedValuesCount => MemberTypeInfo.Infos.Count;

public override bool IsTypeNameMatching(Type type)
=> type.Assembly == typeof(object).Assembly
&& FormatterServices.GetTypeFullNameIncludingTypeForwards(type) == ClassInfo.Name.FullName;
&& FormatterServices.GetTypeFullNameIncludingTypeForwards(type) == TypeName.FullName;

internal static SystemClassWithMembersAndTypesRecord Parse(BinaryReader reader, PayloadOptions options)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,4 @@ internal static TypeName ReadTypeName(this BinaryReader binaryReader, PayloadOpt

return typeName;
}

internal static AssemblyNameInfo ReadLibraryName(this BinaryReader binaryReader)
{
string name = binaryReader.ReadString();
if (!AssemblyNameInfo.TryParse(name.AsSpan(), out AssemblyNameInfo? libraryName))
{
throw new SerializationException($"Invalid library name: '{name}'");
}

return libraryName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ namespace System.Runtime.Serialization.BinaryFormat;
internal static class FormatterServices
{
private static AssemblyNameInfo? s_coreLibAssemblyName;
private static string? s_coreLibRawName;

internal static AssemblyNameInfo CoreLibAssemblyName => s_coreLibAssemblyName ??= AssemblyNameInfo.Parse(GetAssemblyNameIncludingTypeForwards(typeof(object)).AsSpan());

internal static string CoreLibRawName => s_coreLibRawName ??= CoreLibAssemblyName.FullName;

internal static string GetAssemblyNameIncludingTypeForwards(Type type)
{
// Special case types like arrays
Expand Down

0 comments on commit 6354a91

Please sign in to comment.