Skip to content

Commit

Permalink
PR feedback; Move into TempAssembly; Strong/Weak lookup tables; Test …
Browse files Browse the repository at this point in the history
…refinement
  • Loading branch information
StephenMolloy committed Sep 28, 2021
1 parent 55eb88b commit d9e96dc
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;

namespace System.Xml.Serialization
{
Expand Down Expand Up @@ -150,79 +151,82 @@ internal void InitAssemblyMethods(XmlMapping[] xmlMappings)
contract = null;
string? serializerName = null;

// check to see if we loading explicit pre-generated assembly
object[] attrs = type.GetCustomAttributes(typeof(System.Xml.Serialization.XmlSerializerAssemblyAttribute), false);
if (attrs.Length == 0)
using (AssemblyLoadContext.EnterContextualReflection(type.Assembly))
{
// Guess serializer name: if parent assembly signed use strong name
AssemblyName name = type.Assembly.GetName();
serializerName = Compiler.GetTempAssemblyName(name, defaultNamespace);
// use strong name
name.Name = serializerName;
name.CodeBase = null;
name.CultureInfo = CultureInfo.InvariantCulture;

try
{
serializer = Assembly.Load(name);
}
catch (Exception e)
{
if (e is OutOfMemoryException)
// check to see if we loading explicit pre-generated assembly
object[] attrs = type.GetCustomAttributes(typeof(System.Xml.Serialization.XmlSerializerAssemblyAttribute), false);
if (attrs.Length == 0)
{
// Guess serializer name: if parent assembly signed use strong name
AssemblyName name = type.Assembly.GetName();
serializerName = Compiler.GetTempAssemblyName(name, defaultNamespace);
// use strong name
name.Name = serializerName;
name.CodeBase = null;
name.CultureInfo = CultureInfo.InvariantCulture;

try
{
throw;
serializer = Assembly.Load(name);
}
}

serializer ??= LoadAssemblyByPath(type, serializerName);

if (serializer == null)
{
if (XmlSerializer.Mode == SerializationMode.PreGenOnly)
catch (Exception e)
{
throw new Exception(SR.Format(SR.FailLoadAssemblyUnderPregenMode, serializerName));
if (e is OutOfMemoryException)
{
throw;
}
}

return null;
}
serializer ??= LoadAssemblyByPath(type, serializerName);

if (!IsSerializerVersionMatch(serializer, type, defaultNamespace))
{
XmlSerializationEventSource.Log.XmlSerializerExpired(serializerName, type.FullName!);
return null;
}
}
else
{
System.Xml.Serialization.XmlSerializerAssemblyAttribute assemblyAttribute = (System.Xml.Serialization.XmlSerializerAssemblyAttribute)attrs[0];
if (assemblyAttribute.AssemblyName != null && assemblyAttribute.CodeBase != null)
throw new InvalidOperationException(SR.Format(SR.XmlPregenInvalidXmlSerializerAssemblyAttribute, "AssemblyName", "CodeBase"));
if (serializer == null)
{
if (XmlSerializer.Mode == SerializationMode.PreGenOnly)
{
throw new Exception(SR.Format(SR.FailLoadAssemblyUnderPregenMode, serializerName));
}

// found XmlSerializerAssemblyAttribute attribute, it should have all needed information to load the pre-generated serializer
if (assemblyAttribute.AssemblyName != null)
{
serializerName = assemblyAttribute.AssemblyName;
serializer = Assembly.Load(serializerName); // LoadWithPartialName just does this in .Net Core; changing the obsolete call.
}
else if (assemblyAttribute.CodeBase != null && assemblyAttribute.CodeBase.Length > 0)
{
serializerName = assemblyAttribute.CodeBase;
serializer = Assembly.LoadFrom(serializerName);
return null;
}

if (!IsSerializerVersionMatch(serializer, type, defaultNamespace))
{
XmlSerializationEventSource.Log.XmlSerializerExpired(serializerName, type.FullName!);
return null;
}
}
else
{
serializerName = type.Assembly.FullName;
serializer = type.Assembly;
}
if (serializer == null)
{
throw new FileNotFoundException(null, serializerName);
System.Xml.Serialization.XmlSerializerAssemblyAttribute assemblyAttribute = (System.Xml.Serialization.XmlSerializerAssemblyAttribute)attrs[0];
if (assemblyAttribute.AssemblyName != null && assemblyAttribute.CodeBase != null)
throw new InvalidOperationException(SR.Format(SR.XmlPregenInvalidXmlSerializerAssemblyAttribute, "AssemblyName", "CodeBase"));

// found XmlSerializerAssemblyAttribute attribute, it should have all needed information to load the pre-generated serializer
if (assemblyAttribute.AssemblyName != null)
{
serializerName = assemblyAttribute.AssemblyName;
serializer = Assembly.Load(serializerName); // LoadWithPartialName just does this in .Net Core; changing the obsolete call.
}
else if (assemblyAttribute.CodeBase != null && assemblyAttribute.CodeBase.Length > 0)
{
serializerName = assemblyAttribute.CodeBase;
serializer = Assembly.LoadFrom(serializerName);
}
else
{
serializerName = type.Assembly.FullName;
serializer = type.Assembly;
}
if (serializer == null)
{
throw new FileNotFoundException(null, serializerName);
}
}
Type contractType = GetTypeFromAssembly(serializer, "XmlSerializerContract");
contract = (XmlSerializerImplementation)Activator.CreateInstance(contractType)!;
if (contract.CanSerialize(type))
return serializer;
}
Type contractType = GetTypeFromAssembly(serializer, "XmlSerializerContract");
contract = (XmlSerializerImplementation)Activator.CreateInstance(contractType)!;
if (contract.CanSerialize(type))
return serializer;

return null;
}
Expand Down Expand Up @@ -452,79 +456,83 @@ internal static bool GenerateSerializerToStream(XmlMapping[] xmlMappings, Type?[
[RequiresUnreferencedCode("calls GenerateElement")]
internal static Assembly GenerateRefEmitAssembly(XmlMapping[] xmlMappings, Type?[]? types, string? defaultNamespace)
{
var mainType = (types != null && types.Length > 0) ? types[0] : null;
var scopeTable = new Dictionary<TypeScope, XmlMapping>();
foreach (XmlMapping mapping in xmlMappings)
scopeTable[mapping.Scope!] = mapping;
TypeScope[] scopes = new TypeScope[scopeTable.Keys.Count];
scopeTable.Keys.CopyTo(scopes, 0);

string assemblyName = "Microsoft.GeneratedCode";
AssemblyBuilder assemblyBuilder = CodeGenerator.CreateAssemblyBuilder(assemblyName);
// Add AssemblyVersion attribute to match parent assembly version
if (types != null && types.Length > 0 && types[0] != null)
{
ConstructorInfo AssemblyVersionAttribute_ctor = typeof(AssemblyVersionAttribute).GetConstructor(
new Type[] { typeof(string) }
)!;
string assemblyVersion = types[0]!.Assembly.GetName().Version!.ToString();
assemblyBuilder.SetCustomAttribute(new CustomAttributeBuilder(AssemblyVersionAttribute_ctor, new object[] { assemblyVersion }));
}
CodeIdentifiers classes = new CodeIdentifiers();
classes.AddUnique("XmlSerializationWriter", "XmlSerializationWriter");
classes.AddUnique("XmlSerializationReader", "XmlSerializationReader");
string? suffix = null;
if (types != null && types.Length == 1 && types[0] != null)
using (AssemblyLoadContext.EnterContextualReflection(mainType?.Assembly))
{
suffix = CodeIdentifier.MakeValid(types[0]!.Name);
if (types[0]!.IsArray)
string assemblyName = "Microsoft.GeneratedCode";
AssemblyBuilder assemblyBuilder = CodeGenerator.CreateAssemblyBuilder(assemblyName);
// Add AssemblyVersion attribute to match parent assembly version
if (mainType != null)
{
ConstructorInfo AssemblyVersionAttribute_ctor = typeof(AssemblyVersionAttribute).GetConstructor(
new Type[] { typeof(string) }
)!;
string assemblyVersion = mainType.Assembly.GetName().Version!.ToString();
assemblyBuilder.SetCustomAttribute(new CustomAttributeBuilder(AssemblyVersionAttribute_ctor, new object[] { assemblyVersion }));
}
CodeIdentifiers classes = new CodeIdentifiers();
classes.AddUnique("XmlSerializationWriter", "XmlSerializationWriter");
classes.AddUnique("XmlSerializationReader", "XmlSerializationReader");
string? suffix = null;
if (mainType != null)
{
suffix += "Array";
suffix = CodeIdentifier.MakeValid(mainType.Name);
if (mainType.IsArray)
{
suffix += "Array";
}
}
}

ModuleBuilder moduleBuilder = CodeGenerator.CreateModuleBuilder(assemblyBuilder, assemblyName);
ModuleBuilder moduleBuilder = CodeGenerator.CreateModuleBuilder(assemblyBuilder, assemblyName);

string writerClass = "XmlSerializationWriter" + suffix;
writerClass = classes.AddUnique(writerClass, writerClass);
XmlSerializationWriterILGen writerCodeGen = new XmlSerializationWriterILGen(scopes, "public", writerClass);
writerCodeGen.ModuleBuilder = moduleBuilder;
string writerClass = "XmlSerializationWriter" + suffix;
writerClass = classes.AddUnique(writerClass, writerClass);
XmlSerializationWriterILGen writerCodeGen = new XmlSerializationWriterILGen(scopes, "public", writerClass);
writerCodeGen.ModuleBuilder = moduleBuilder;

writerCodeGen.GenerateBegin();
string[] writeMethodNames = new string[xmlMappings.Length];
writerCodeGen.GenerateBegin();
string[] writeMethodNames = new string[xmlMappings.Length];

for (int i = 0; i < xmlMappings.Length; i++)
{
writeMethodNames[i] = writerCodeGen.GenerateElement(xmlMappings[i])!;
}
Type writerType = writerCodeGen.GenerateEnd();
for (int i = 0; i < xmlMappings.Length; i++)
{
writeMethodNames[i] = writerCodeGen.GenerateElement(xmlMappings[i])!;
}
Type writerType = writerCodeGen.GenerateEnd();

string readerClass = "XmlSerializationReader" + suffix;
readerClass = classes.AddUnique(readerClass, readerClass);
XmlSerializationReaderILGen readerCodeGen = new XmlSerializationReaderILGen(scopes, "public", readerClass);
string readerClass = "XmlSerializationReader" + suffix;
readerClass = classes.AddUnique(readerClass, readerClass);
XmlSerializationReaderILGen readerCodeGen = new XmlSerializationReaderILGen(scopes, "public", readerClass);

readerCodeGen.ModuleBuilder = moduleBuilder;
readerCodeGen.CreatedTypes.Add(writerType.Name, writerType);
readerCodeGen.ModuleBuilder = moduleBuilder;
readerCodeGen.CreatedTypes.Add(writerType.Name, writerType);

readerCodeGen.GenerateBegin();
string[] readMethodNames = new string[xmlMappings.Length];
for (int i = 0; i < xmlMappings.Length; i++)
{
readMethodNames[i] = readerCodeGen.GenerateElement(xmlMappings[i])!;
}
readerCodeGen.GenerateEnd(readMethodNames, xmlMappings, types!);
readerCodeGen.GenerateBegin();
string[] readMethodNames = new string[xmlMappings.Length];
for (int i = 0; i < xmlMappings.Length; i++)
{
readMethodNames[i] = readerCodeGen.GenerateElement(xmlMappings[i])!;
}
readerCodeGen.GenerateEnd(readMethodNames, xmlMappings, types!);

string baseSerializer = readerCodeGen.GenerateBaseSerializer("XmlSerializer1", readerClass, writerClass, classes);
var serializers = new Dictionary<string, string>();
for (int i = 0; i < xmlMappings.Length; i++)
{
if (!serializers.ContainsKey(xmlMappings[i].Key!))
string baseSerializer = readerCodeGen.GenerateBaseSerializer("XmlSerializer1", readerClass, writerClass, classes);
var serializers = new Dictionary<string, string>();
for (int i = 0; i < xmlMappings.Length; i++)
{
serializers[xmlMappings[i].Key!] = readerCodeGen.GenerateTypedSerializer(readMethodNames[i], writeMethodNames[i], xmlMappings[i], classes, baseSerializer, readerClass, writerClass);
if (!serializers.ContainsKey(xmlMappings[i].Key!))
{
serializers[xmlMappings[i].Key!] = readerCodeGen.GenerateTypedSerializer(readMethodNames[i], writeMethodNames[i], xmlMappings[i], classes, baseSerializer, readerClass, writerClass);
}
}
}
readerCodeGen.GenerateSerializerContract("XmlSerializerContract", xmlMappings, types!, readerClass, readMethodNames, writerClass, writeMethodNames, serializers);
readerCodeGen.GenerateSerializerContract("XmlSerializerContract", xmlMappings, types!, readerClass, readMethodNames, writerClass, writeMethodNames, serializers);

return writerType.Assembly;
return writerType.Assembly;
}
}

private static MethodInfo GetMethodFromType(
Expand Down Expand Up @@ -667,9 +675,9 @@ internal sealed class TempMethodDictionary : Dictionary<string, TempMethod>
internal sealed class TempAssemblyCacheKey
{
private readonly string? _ns;
private readonly object _type;
private readonly Type _type;

internal TempAssemblyCacheKey(string? ns, object type)
internal TempAssemblyCacheKey(string? ns, Type type)
{
_type = type;
_ns = ns;
Expand All @@ -691,33 +699,56 @@ public override int GetHashCode()

internal sealed class TempAssemblyCache
{
private ConditionalWeakTable<TempAssemblyCacheKey, TempAssembly> _cache = new ConditionalWeakTable<TempAssemblyCacheKey, TempAssembly>();
private Dictionary<TempAssemblyCacheKey, TempAssembly> _cache = new Dictionary<TempAssemblyCacheKey, TempAssembly>();
private ConditionalWeakTable<AssemblyLoadContext, Dictionary<TempAssemblyCacheKey, TempAssembly>> _collectibleCaches = new ConditionalWeakTable<AssemblyLoadContext, Dictionary<TempAssemblyCacheKey, TempAssembly>>();

internal TempAssembly? this[string? ns, object o]
internal TempAssembly? this[string? ns, Type t]
{
get
{
TempAssembly? tempAssembly;
_cache.TryGetValue(new TempAssemblyCacheKey(ns, o), out tempAssembly);
TempAssemblyCacheKey key = new TempAssemblyCacheKey(ns, t);

if (_cache.TryGetValue(key, out tempAssembly))
return tempAssembly;

var alc = AssemblyLoadContext.GetLoadContext(t.Assembly);
Dictionary<TempAssemblyCacheKey, TempAssembly>? cache;

if (alc != null && _collectibleCaches.TryGetValue(alc, out cache))
cache.TryGetValue(key, out tempAssembly);

return tempAssembly;
}
}

internal void Add(string? ns, object o, TempAssembly assembly)
internal void Add(string? ns, Type t, TempAssembly assembly)
{
TempAssemblyCacheKey key = new TempAssemblyCacheKey(ns, o);
var alc = AssemblyLoadContext.GetLoadContext(t.Assembly);
TempAssemblyCacheKey key = new TempAssemblyCacheKey(ns, t);

lock (this)
{
TempAssembly? tempAssembly;
if (_cache.TryGetValue(key, out tempAssembly) && tempAssembly == assembly)
TempAssembly? tempAssembly = this[ns, t];

if (tempAssembly == assembly)
return;
ConditionalWeakTable<TempAssemblyCacheKey, TempAssembly> _copy = new ConditionalWeakTable<TempAssemblyCacheKey, TempAssembly>(); // clone
foreach (KeyValuePair<TempAssemblyCacheKey, TempAssembly> kvp in _cache)

if (alc != null && alc.IsCollectible)
{
Dictionary<TempAssemblyCacheKey, TempAssembly>? collectibleCache;
if (!_collectibleCaches.TryGetValue(alc, out collectibleCache))
{
collectibleCache = new Dictionary<TempAssemblyCacheKey, TempAssembly>();
_collectibleCaches.Add(alc, collectibleCache);
}

collectibleCache.Add(key, assembly);
}
else
{
_copy.Add(kvp.Key, kvp.Value);
_cache.Add(key, assembly);
}
_copy.Add(key, assembly);
_cache = _copy;
}
}
}
Expand Down
Loading

0 comments on commit d9e96dc

Please sign in to comment.