Skip to content

Commit

Permalink
Options Source Gen Fixes (#91363)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekgh authored Aug 31, 2023
1 parent 97a98cd commit a8d5e7d
Show file tree
Hide file tree
Showing 9 changed files with 740 additions and 763 deletions.
85 changes: 66 additions & 19 deletions src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;

namespace Microsoft.Extensions.Options.Generators
{
Expand All @@ -25,6 +27,7 @@ internal sealed class Emitter : EmitterBase
private string _staticValidationAttributeHolderClassFQN;
private string _staticValidatorHolderClassFQN;
private string _modifier;
private string _TryGetValueNullableAnnotation;

private sealed record StaticFieldInfo(string FieldTypeFQN, int FieldOrder, string FieldName, IList<string> InstantiationLines);

Expand All @@ -37,13 +40,14 @@ public Emitter(Compilation compilation, bool emitPreamble = true) : base(emitPre
else
{
_modifier = "internal";
string suffix = $"_{new Random().Next():X8}";
string suffix = $"_{GetNonRandomizedHashCode(compilation.SourceModule.Name):X8}";
_staticValidationAttributeHolderClassName += suffix;
_staticValidatorHolderClassName += suffix;
}

_staticValidationAttributeHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidationAttributeHolderClassName}";
_staticValidatorHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidatorHolderClassName}";
_TryGetValueNullableAnnotation = GetNullableAnnotationStringForTryValidateValueToUseInGeneratedCode(compilation);
}

public string Emit(
Expand All @@ -65,6 +69,31 @@ public string Emit(
return Capture();
}

/// <summary>
/// Returns the nullable annotation string to use in the code generation according to the first parameter of
/// <see cref="System.ComponentModel.DataAnnotations.Validator.TryValidateValue(object, ValidationContext, ICollection{ValidationResult}, IEnumerable{ValidationAttribute})"/> is nullable annotated.
/// </summary>
/// <param name="compilation">The <see cref="Compilation"/> to consider for analysis.</param>
/// <returns>"!" if the first parameter is not nullable annotated, otherwise an empty string.</returns>
/// <remarks>
/// In .NET 8.0 we have changed the nullable annotation on first parameter of the method cref="System.ComponentModel.DataAnnotations.Validator.TryValidateValue(object, ValidationContext, ICollection{ValidationResult}, IEnumerable{ValidationAttribute})"/>
/// The source generator need to detect if we need to append "!" to the first parameter of the method call when running on down-level versions.
/// </remarks>
private static string GetNullableAnnotationStringForTryValidateValueToUseInGeneratedCode(Compilation compilation)
{
INamedTypeSymbol? validatorTypeSymbol = compilation.GetBestTypeByMetadataName("System.ComponentModel.DataAnnotations.Validator");
if (validatorTypeSymbol is not null)
{
ImmutableArray<ISymbol> members = validatorTypeSymbol.GetMembers("TryValidateValue");
if (members.Length == 1 && members[0] is IMethodSymbol tryValidateValueMethod)
{
return tryValidateValueMethod.Parameters[0].NullableAnnotation == NullableAnnotation.NotAnnotated ? "!" : string.Empty;
}
}

return "!";
}

private void GenValidatorType(ValidatorType vt, ref Dictionary<string, StaticFieldInfo> staticValidationAttributesDict, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
{
if (vt.Namespace.Length > 0)
Expand Down Expand Up @@ -161,7 +190,7 @@ private void GenModelSelfValidationIfNecessary(ValidatedModel modelToValidate)
{
if (modelToValidate.SelfValidates)
{
OutLn($"builder.AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context));");
OutLn($"(builder ??= new()).AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context));");
OutLn();
}
}
Expand All @@ -182,8 +211,7 @@ private void GenModelValidationMethod(

OutLn($"public {(makeStatic ? "static " : string.Empty)}global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, {modelToValidate.Name} options)");
OutOpenBrace();
OutLn($"var baseName = (string.IsNullOrEmpty(name) ? \"{modelToValidate.SimpleName}\" : name) + \".\";");
OutLn($"var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder();");
OutLn($"global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null;");
OutLn($"var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options);");

int capacity = modelToValidate.MembersToValidate.Max(static vm => vm.ValidationAttributes.Count);
Expand All @@ -199,33 +227,33 @@ private void GenModelValidationMethod(
{
if (vm.ValidationAttributes.Count > 0)
{
GenMemberValidation(vm, ref staticValidationAttributesDict, cleanListsBeforeUse);
GenMemberValidation(vm, modelToValidate.SimpleName, ref staticValidationAttributesDict, cleanListsBeforeUse);
cleanListsBeforeUse = true;
OutLn();
}

if (vm.TransValidatorType is not null)
{
GenTransitiveValidation(vm, ref staticValidatorsDict);
GenTransitiveValidation(vm, modelToValidate.SimpleName, ref staticValidatorsDict);
OutLn();
}

if (vm.EnumerationValidatorType is not null)
{
GenEnumerationValidation(vm, ref staticValidatorsDict);
GenEnumerationValidation(vm, modelToValidate.SimpleName, ref staticValidatorsDict);
OutLn();
}
}

GenModelSelfValidationIfNecessary(modelToValidate);
OutLn($"return builder.Build();");
OutLn($"return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build();");
OutCloseBrace();
}

private void GenMemberValidation(ValidatedMember vm, ref Dictionary<string, StaticFieldInfo> staticValidationAttributesDict, bool cleanListsBeforeUse)
private void GenMemberValidation(ValidatedMember vm, string modelName, ref Dictionary<string, StaticFieldInfo> staticValidationAttributesDict, bool cleanListsBeforeUse)
{
OutLn($"context.MemberName = \"{vm.Name}\";");
OutLn($"context.DisplayName = baseName + \"{vm.Name}\";");
OutLn($"context.DisplayName = string.IsNullOrEmpty(name) ? \"{modelName}.{vm.Name}\" : $\"{{name}}.{vm.Name}\";");

if (cleanListsBeforeUse)
{
Expand All @@ -239,9 +267,9 @@ private void GenMemberValidation(ValidatedMember vm, ref Dictionary<string, Stat
OutLn($"validationAttributes.Add({_staticValidationAttributeHolderClassFQN}.{staticValidationAttributeInstance.FieldName});");
}

OutLn($"if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.{vm.Name}!, context, validationResults, validationAttributes))");
OutLn($"if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.{vm.Name}{_TryGetValueNullableAnnotation}, context, validationResults, validationAttributes))");
OutOpenBrace();
OutLn($"builder.AddResults(validationResults);");
OutLn($"(builder ??= new()).AddResults(validationResults);");
OutCloseBrace();
}

Expand Down Expand Up @@ -305,7 +333,7 @@ private StaticFieldInfo GetOrAddStaticValidationAttribute(ref Dictionary<string,
return staticValidationAttributeInstance;
}

private void GenTransitiveValidation(ValidatedMember vm, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
private void GenTransitiveValidation(ValidatedMember vm, string modelName, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
{
string callSequence;
if (vm.TransValidateTypeIsSynthetic)
Expand All @@ -321,20 +349,22 @@ private void GenTransitiveValidation(ValidatedMember vm, ref Dictionary<string,

var valueAccess = (vm.IsNullable && vm.IsValueType) ? ".Value" : string.Empty;

var baseName = $"string.IsNullOrEmpty(name) ? \"{modelName}.{vm.Name}\" : $\"{{name}}.{vm.Name}\"";

if (vm.IsNullable)
{
OutLn($"if (options.{vm.Name} is not null)");
OutOpenBrace();
OutLn($"builder.AddResult({callSequence}.Validate(baseName + \"{vm.Name}\", options.{vm.Name}{valueAccess}));");
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({baseName}, options.{vm.Name}{valueAccess}));");
OutCloseBrace();
}
else
{
OutLn($"builder.AddResult({callSequence}.Validate(baseName + \"{vm.Name}\", options.{vm.Name}{valueAccess}));");
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({baseName}, options.{vm.Name}{valueAccess}));");
}
}

private void GenEnumerationValidation(ValidatedMember vm, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
private void GenEnumerationValidation(ValidatedMember vm, string modelName, ref Dictionary<string, StaticFieldInfo> staticValidatorsDict)
{
var valueAccess = (vm.IsValueType && vm.IsNullable) ? ".Value" : string.Empty;
var enumeratedValueAccess = (vm.EnumeratedIsNullable && vm.EnumeratedIsValueType) ? ".Value" : string.Empty;
Expand Down Expand Up @@ -365,22 +395,25 @@ private void GenEnumerationValidation(ValidatedMember vm, ref Dictionary<string,
{
OutLn($"if (o is not null)");
OutOpenBrace();
OutLn($"builder.AddResult({callSequence}.Validate(baseName + $\"{vm.Name}[{{count}}]\", o{enumeratedValueAccess}));");
var propertyName = $"string.IsNullOrEmpty(name) ? $\"{modelName}.{vm.Name}[{{count}}]\" : $\"{{name}}.{vm.Name}[{{count}}]\"";
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({propertyName}, o{enumeratedValueAccess}));");
OutCloseBrace();

if (!vm.EnumeratedMayBeNull)
{
OutLn($"else");
OutOpenBrace();
OutLn($"builder.AddError(baseName + $\"{vm.Name}[{{count}}] is null\");");
var error = $"string.IsNullOrEmpty(name) ? $\"{modelName}.{vm.Name}[{{count}}] is null\" : $\"{{name}}.{vm.Name}[{{count}}] is null\"";
OutLn($"(builder ??= new()).AddError({error});");
OutCloseBrace();
}

OutLn($"count++;");
}
else
{
OutLn($"builder.AddResult({callSequence}.Validate(baseName + $\"{vm.Name}[{{count++}}]\", o{enumeratedValueAccess}));");
var propertyName = $"string.IsNullOrEmpty(name) ? $\"{modelName}.{vm.Name}[{{count++}}] is null\" : $\"{{name}}.{vm.Name}[{{count++}}] is null\"";
OutLn($"(builder ??= new()).AddResult({callSequence}.Validate({propertyName}, o{enumeratedValueAccess}));");
}

OutCloseBrace();
Expand All @@ -405,5 +438,19 @@ private StaticFieldInfo GetOrAddStaticValidator(ref Dictionary<string, StaticFie

return staticValidatorInstance;
}

/// <summary>
/// Returns a non-randomized hash code for the given string.
/// We always return a positive value.
/// </summary>
internal static int GetNonRandomizedHashCode(string s)
{
uint result = 2166136261u;
foreach (char c in s)
{
result = (c ^ result) * 16777619;
}
return Math.Abs((int)result);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

<ItemGroup>
<Compile Include="$(CoreLibSharedDir)System\Runtime\CompilerServices\IsExternalInit.cs" Link="Common\System\Runtime\CompilerServices\IsExternalInit.cs" />
<Compile Include="$(CommonPath)\Roslyn\GetBestTypeByMetadataName.cs" Link="Common\Roslyn\GetBestTypeByMetadataName.cs" />
<Compile Include="DiagDescriptors.cs" />
<Compile Include="DiagDescriptorsBase.cs" />
<Compile Include="Emitter.cs" />
Expand Down
7 changes: 7 additions & 0 deletions src/libraries/Microsoft.Extensions.Options/gen/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ private static bool HasOpenGenerics(ITypeSymbol type, out string genericType)
type = ((INamedTypeSymbol)type).TypeArguments[0];
}

// Check first if the type is IEnumerable<T> interface
if (SymbolEqualityComparer.Default.Equals(type.OriginalDefinition, _symbolHolder.GenericIEnumerableSymbol))
{
return ((INamedTypeSymbol)type).TypeArguments[0];
}

// Check first if the type implement IEnumerable<T> interface
foreach (var implementingInterface in type.AllInterfaces)
{
if (SymbolEqualityComparer.Default.Equals(implementingInterface.OriginalDefinition, _compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ internal sealed record class SymbolHolder(
INamedTypeSymbol DataTypeAttributeSymbol,
INamedTypeSymbol ValidateOptionsSymbol,
INamedTypeSymbol IValidatableObjectSymbol,
INamedTypeSymbol GenericIEnumerableSymbol,
INamedTypeSymbol TypeSymbol,
INamedTypeSymbol? ValidateObjectMembersAttributeSymbol,
INamedTypeSymbol? ValidateEnumeratedItemsAttributeSymbol);
INamedTypeSymbol ValidateObjectMembersAttributeSymbol,
INamedTypeSymbol ValidateEnumeratedItemsAttributeSymbol);
}
27 changes: 12 additions & 15 deletions src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,33 @@ internal static class SymbolLoader
internal const string TypeOfType = "System.Type";
internal const string ValidateObjectMembersAttribute = "Microsoft.Extensions.Options.ValidateObjectMembersAttribute";
internal const string ValidateEnumeratedItemsAttribute = "Microsoft.Extensions.Options.ValidateEnumeratedItemsAttribute";
internal const string GenericIEnumerableType = "System.Collections.Generic.IEnumerable`1";

public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHolder)
{
INamedTypeSymbol? GetSymbol(string metadataName, bool optional = false)
{
var symbol = compilation.GetTypeByMetadataName(metadataName);
if (symbol == null && !optional)
{
return null;
}

return symbol;
}
INamedTypeSymbol? GetSymbol(string metadataName) => compilation.GetTypeByMetadataName(metadataName);

// required
var optionsValidatorSymbol = GetSymbol(OptionsValidatorAttribute);
var validationAttributeSymbol = GetSymbol(ValidationAttribute);
var dataTypeAttributeSymbol = GetSymbol(DataTypeAttribute);
var ivalidatableObjectSymbol = GetSymbol(IValidatableObjectType);
var validateOptionsSymbol = GetSymbol(IValidateOptionsType);
var genericIEnumerableSymbol = GetSymbol(GenericIEnumerableType);
var typeSymbol = GetSymbol(TypeOfType);
var validateObjectMembersAttribute = GetSymbol(ValidateObjectMembersAttribute);
var validateEnumeratedItemsAttribute = GetSymbol(ValidateEnumeratedItemsAttribute);

#pragma warning disable S1067 // Expressions should not be too complex
if (optionsValidatorSymbol == null ||
validationAttributeSymbol == null ||
dataTypeAttributeSymbol == null ||
ivalidatableObjectSymbol == null ||
validateOptionsSymbol == null ||
typeSymbol == null)
genericIEnumerableSymbol == null ||
typeSymbol == null ||
validateObjectMembersAttribute == null ||
validateEnumeratedItemsAttribute == null)
{
symbolHolder = default;
return false;
Expand All @@ -56,11 +54,10 @@ public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHold
dataTypeAttributeSymbol,
validateOptionsSymbol,
ivalidatableObjectSymbol,
genericIEnumerableSymbol,
typeSymbol,

// optional
GetSymbol(ValidateObjectMembersAttribute, optional: true),
GetSymbol(ValidateEnumeratedItemsAttribute, optional: true));
validateObjectMembersAttribute,
validateEnumeratedItemsAttribute);

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,31 +70,30 @@ partial struct MyOptionsValidator
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")]
public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options)
{
var baseName = (string.IsNullOrEmpty(name) ? "MyOptions" : name) + ".";
var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder();
global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null;
var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options);
var validationResults = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationResult>();
var validationAttributes = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(1);
context.MemberName = "Val1";
context.DisplayName = baseName + "Val1";
context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val1" : $"{name}.Val1";
validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1);
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes))
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes))
{
builder.AddResults(validationResults);
(builder ??= new()).AddResults(validationResults);
}
context.MemberName = "Val2";
context.DisplayName = baseName + "Val2";
context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val2" : $"{name}.Val2";
validationResults.Clear();
validationAttributes.Clear();
validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2);
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes))
if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes))
{
builder.AddResults(validationResults);
(builder ??= new()).AddResults(validationResults);
}
return builder.Build();
return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build();
}
}
}
Expand Down
Loading

0 comments on commit a8d5e7d

Please sign in to comment.