diff --git a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
index b43e6a90843..119d140388a 100644
--- a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
+++ b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
@@ -3,11 +3,12 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
-using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
+using JetbrainsNotNull = JetBrains.Annotations.NotNullAttribute;
namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
{
@@ -28,7 +29,7 @@ public abstract class NonNullableConventionBase : IModelFinalizedConvention
/// Creates a new instance of .
///
/// Parameter object containing dependencies for this convention.
- protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDependencies dependencies)
+ protected NonNullableConventionBase([JetbrainsNotNull] ProviderConventionSetBuilderDependencies dependencies)
{
Dependencies = dependencies;
}
@@ -38,27 +39,6 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend
///
protected virtual ProviderConventionSetBuilderDependencies Dependencies { get; }
- private byte? GetNullabilityContextFlag(NonNullabilityConventionState state, Attribute[] attributes)
- {
- if (attributes.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute attribute)
- {
- var attributeType = attribute.GetType();
-
- if (attributeType != state.NullableContextAttrType)
- {
- state.NullableContextFlagFieldInfo = attributeType.GetField("Flag");
- state.NullableContextAttrType = attributeType;
- }
-
- if (state.NullableContextFlagFieldInfo?.GetValue(attribute) is byte flag)
- {
- return flag;
- }
- }
-
- return null;
- }
-
///
/// Returns a value indicating whether the member type is a non-nullable reference type.
///
@@ -66,8 +46,8 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend
/// The member info.
/// true if the member type is a non-nullable reference type.
protected virtual bool IsNonNullableRefType(
- [NotNull] IConventionModelBuilder modelBuilder,
- [NotNull] MemberInfo memberInfo)
+ [JetbrainsNotNull] IConventionModelBuilder modelBuilder,
+ [JetbrainsNotNull] MemberInfo memberInfo)
{
if (memberInfo.GetMemberType().IsValueType)
{
@@ -76,6 +56,19 @@ protected virtual bool IsNonNullableRefType(
var state = GetOrInitializeState(modelBuilder);
+ // First check for [MaybeNull] on the return value. If it exists, the member is nullable.
+ var isMaybeNull = memberInfo switch
+ {
+ FieldInfo f => f.GetCustomAttribute() != null,
+ PropertyInfo p => p.GetMethod?.ReturnParameter?.GetCustomAttribute() != null,
+ _ => false
+ };
+
+ if (isMaybeNull)
+ {
+ return false;
+ }
+
// For C# 8.0 nullable types, the C# currently synthesizes a NullableAttribute that expresses nullability into assemblies
// it produces. If the model is spread across more than one assembly, there will be multiple versions of this attribute,
// so look for it by name, caching to avoid reflection on every check.
@@ -83,7 +76,7 @@ protected virtual bool IsNonNullableRefType(
// First look for NullableAttribute on the member itself
if (Attribute.GetCustomAttributes(memberInfo)
- .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute)
+ .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute)
{
var attributeType = attribute.GetType();
@@ -103,33 +96,32 @@ protected virtual bool IsNonNullableRefType(
var type = memberInfo.DeclaringType;
if (type != null)
{
- if (state.TypeNonNullabilityContextCache.TryGetValue(type, out var cachedTypeNonNullable))
+ if (state.TypeCache.TryGetValue(type, out var cachedTypeNonNullable))
{
return cachedTypeNonNullable;
}
- var typeContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(type));
- if (typeContextFlag.HasValue)
+ if (Attribute.GetCustomAttributes(type)
+ .FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute contextAttr)
{
- return state.TypeNonNullabilityContextCache[type] = typeContextFlag.Value == 1;
+ var attributeType = contextAttr.GetType();
+
+ if (attributeType != state.NullableContextAttrType)
+ {
+ state.NullableContextFlagFieldInfo = attributeType.GetField("Flag");
+ state.NullableContextAttrType = attributeType;
+ }
+
+ if (state.NullableContextFlagFieldInfo?.GetValue(contextAttr) is byte flag)
+ {
+ return state.TypeCache[type] = flag == 1;
+ }
}
- }
-
- // Not found at the type level, try at the module level
- var module = memberInfo.Module;
- if (!state.ModuleNonNullabilityContextCache.TryGetValue(module, out var moduleNonNullable))
- {
- var moduleContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(memberInfo.Module));
- moduleNonNullable = state.ModuleNonNullabilityContextCache[module] =
- moduleContextFlag.HasValue && moduleContextFlag == 1;
- }
- if (type != null)
- {
- state.TypeNonNullabilityContextCache[type] = moduleNonNullable;
+ return state.TypeCache[type] = false;
}
- return moduleNonNullable;
+ return false;
}
private NonNullabilityConventionState GetOrInitializeState(IConventionModelBuilder modelBuilder)
@@ -152,8 +144,7 @@ private class NonNullabilityConventionState
public Type NullableContextAttrType;
public FieldInfo NullableFlagsFieldInfo;
public FieldInfo NullableContextFlagFieldInfo;
- public Dictionary TypeNonNullabilityContextCache { get; } = new Dictionary();
- public Dictionary ModuleNonNullabilityContextCache { get; } = new Dictionary();
+ public Dictionary TypeCache { get; } = new Dictionary();
}
}
}
diff --git a/test/EFCore.Tests/Metadata/Conventions/NonNullableReferencePropertyConventionTest.cs b/test/EFCore.Tests/Metadata/Conventions/NonNullableReferencePropertyConventionTest.cs
index db6bcc8b7bf..ef4aa7aa29d 100644
--- a/test/EFCore.Tests/Metadata/Conventions/NonNullableReferencePropertyConventionTest.cs
+++ b/test/EFCore.Tests/Metadata/Conventions/NonNullableReferencePropertyConventionTest.cs
@@ -3,6 +3,7 @@
using System;
using System.ComponentModel.DataAnnotations;
+using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal;
@@ -49,32 +50,35 @@ public void Non_nullability_sets_is_nullable_with_conventional_builder()
var modelBuilder = CreateModelBuilder();
var entityTypeBuilder = modelBuilder.Entity();
- Assert.False(entityTypeBuilder.Property(e => e.Name).Metadata.IsNullable);
+ Assert.False(entityTypeBuilder.Property(e => e.NonNullable).Metadata.IsNullable);
}
[ConditionalTheory]
- [InlineData(nameof(A.NullAwareNonNullable), false)]
- [InlineData(nameof(A.NullAwareNullable), true)]
- [InlineData(nameof(A.NullObliviousNonNullable), true)]
- [InlineData(nameof(A.NullObliviousNullable), true)]
- [InlineData(nameof(A.RequiredAndNullable), false)]
- public void Reference_nullability_sets_is_nullable_correctly1(string propertyName, bool expectedNullable)
+ [InlineData(typeof(A), nameof(A.NonNullable), false)]
+ [InlineData(typeof(A), nameof(A.Nullable), true)]
+
+ [InlineData(typeof(A), nameof(A.NonNullablePropertyMaybeNull), true)]
+ [InlineData(typeof(A), nameof(A.NonNullablePropertyAllowNull), false)]
+ [InlineData(typeof(A), nameof(A.NullablePropertyNotNull), true)]
+ [InlineData(typeof(A), nameof(A.NullablePropertyDisallowNull), true)]
+
+ [InlineData(typeof(A), nameof(A.NonNullableFieldMaybeNull), true)]
+ [InlineData(typeof(A), nameof(A.NonNullableFieldAllowNull), false)]
+ [InlineData(typeof(A), nameof(A.NullableFieldNotNull), true)]
+ [InlineData(typeof(A), nameof(A.NullableFieldDisallowNull), true)]
+
+ [InlineData(typeof(A), nameof(A.RequiredAndNullable), false)]
+ [InlineData(typeof(A), nameof(A.NullObliviousNonNullable), true)]
+ [InlineData(typeof(A), nameof(A.NullObliviousNullable), true)]
+
+ [InlineData(typeof(B), nameof(B.NonNullableValueType), false)]
+ [InlineData(typeof(B), nameof(B.NullableValueType), true)]
+ [InlineData(typeof(B), nameof(B.NonNullableRefType), false)]
+ [InlineData(typeof(B), nameof(B.NullableRefType), true)]
+ public void Reference_nullability_sets_is_nullable_correctly(Type type, string propertyName, bool expectedNullable)
{
var modelBuilder = CreateModelBuilder();
- var entityTypeBuilder = modelBuilder.Entity();
-
- Assert.Equal(expectedNullable, entityTypeBuilder.Property(propertyName).Metadata.IsNullable);
- }
-
- [ConditionalTheory]
- [InlineData(nameof(B.NonNullableValueType), false)]
- [InlineData(nameof(B.NullableValueType), true)]
- [InlineData(nameof(B.NonNullableRefType), false)]
- [InlineData(nameof(B.NullableRefType), true)]
- public void Reference_nullability_sets_is_nullable_correctly2(string propertyName, bool expectedNullable)
- {
- var modelBuilder = CreateModelBuilder();
- var entityTypeBuilder = modelBuilder.Entity();
+ var entityTypeBuilder = modelBuilder.Entity(type);
Assert.Equal(expectedNullable, entityTypeBuilder.Property(propertyName).Metadata.IsNullable);
}
@@ -107,10 +111,26 @@ private class A
public int Id { get; set; }
#nullable enable
- public string Name { get; set; } = "";
-
- public string NullAwareNonNullable { get; set; } = "";
- public string? NullAwareNullable { get; set; }
+ public string NonNullable { get; set; } = "";
+ public string? Nullable { get; set; }
+
+ [MaybeNull]
+ public string NonNullablePropertyMaybeNull { get; set; } = "";
+ [AllowNull]
+ public string NonNullablePropertyAllowNull { get; set; } = "";
+ [NotNull]
+ public string? NullablePropertyNotNull { get; set; } = "";
+ [DisallowNull]
+ public string? NullablePropertyDisallowNull { get; set; } = "";
+
+ [MaybeNull]
+ public string NonNullableFieldMaybeNull = "";
+ [AllowNull]
+ public string NonNullableFieldAllowNull = "";
+ [NotNull]
+ public string? NullableFieldNotNull = "";
+ [DisallowNull]
+ public string? NullableFieldDisallowNull = "";
[Required]
public string? RequiredAndNullable { get; set; }