From 766f36d8fd87413162c151218c29055fe7efe33d Mon Sep 17 00:00:00 2001 From: Jeremi Kurdek Date: Mon, 8 Apr 2024 13:05:03 +0200 Subject: [PATCH] fixed lambda return type nullability inference --- .../BindingSourceGenerator.cs | 19 +++++++++---------- .../BindingRepresentationGenTests.cs | 8 ++++---- .../IntegrationTests.cs | 4 ++-- .../SourceGenHelpers.cs | 2 +- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/Controls/src/BindingSourceGen/BindingSourceGenerator.cs b/src/Controls/src/BindingSourceGen/BindingSourceGenerator.cs index c2add5c1ed0e..aacc369ccdd8 100644 --- a/src/Controls/src/BindingSourceGen/BindingSourceGenerator.cs +++ b/src/Controls/src/BindingSourceGen/BindingSourceGenerator.cs @@ -58,8 +58,10 @@ static bool IsSetBindingMethod(SyntaxNode node) static BindingDiagnosticsWrapper GetBindingForGeneration(GeneratorSyntaxContext context, CancellationToken t) { var diagnostics = new List(); - var invocation = (InvocationExpressionSyntax)context.Node; + NullableContext nullableContext = context.SemanticModel.GetNullableContext(context.Node.Span.Start); + var enabledNullable = (nullableContext & NullableContext.Enabled) == NullableContext.Enabled; + var invocation = (InvocationExpressionSyntax)context.Node; var method = (MemberAccessExpressionSyntax)invocation.Expression; var sourceCodeLocation = new SourceCodeLocation( @@ -82,26 +84,23 @@ static BindingDiagnosticsWrapper GetBindingForGeneration(GeneratorSyntaxContext return ReportDiagnostics(lambdaDiagnostics); } - NullableContext nullableContext = context.SemanticModel.GetNullableContext(context.Node.Span.Start); - var enabledNullable = (nullableContext & NullableContext.Enabled) == NullableContext.Enabled; + var lambdaTypeInfo = context.SemanticModel.GetTypeInfo(lambdaBody, t); + if (lambdaTypeInfo.Type == null) + { + return ReportDiagnostics([DiagnosticsFactory.UnableToResolvePath(lambdaBody.GetLocation())]); // TODO: New diagnostic + } var pathParser = new PathParser(context); var (pathDiagnostics, parts) = pathParser.ParsePath(lambdaBody); - if (pathDiagnostics.Length > 0) { return ReportDiagnostics(pathDiagnostics); } - // Sometimes analysing just the return type of the lambda is not enough. TODO: Refactor - // var propertyType = BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaSymbol.ReturnType, enabledNullable); - // var lastMember = parts.Last() is Cast cast ? cast.Part : parts.Last(); - // propertyType = propertyType with { IsNullable = lastMember is ConditionalAccess || propertyType.IsNullable }; - var codeWriterBinding = new CodeWriterBinding( Location: sourceCodeLocation, SourceType: BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaSymbol.Parameters[0].Type, enabledNullable), - PropertyType: BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaSymbol.ReturnType, enabledNullable), + PropertyType: BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaTypeInfo.Type, enabledNullable), Path: parts.ToArray(), GenerateSetter: false //TODO: Implement ); diff --git a/src/Controls/tests/BindingSourceGen.UnitTests/BindingRepresentationGenTests.cs b/src/Controls/tests/BindingSourceGen.UnitTests/BindingRepresentationGenTests.cs index d0915bbea57f..117ae984c7cc 100644 --- a/src/Controls/tests/BindingSourceGen.UnitTests/BindingRepresentationGenTests.cs +++ b/src/Controls/tests/BindingSourceGen.UnitTests/BindingRepresentationGenTests.cs @@ -155,7 +155,7 @@ public void GenerateBindingWithNullableSourceReferenceAndNullableReferenceElemen AssertExtensions.BindingsAreEqual(expectedBinding, codeGeneratorResult); } - [Fact(Skip = "Require checking path for elements that can be null")] + [Fact] public void GenerateBindingWithNullablePropertyReferenceWhenNullableEnabled() { var source = """ @@ -299,7 +299,7 @@ class Foo AssertExtensions.BindingsAreEqual(expectedBinding, codeGeneratorResult); } - [Fact(Skip = "Requires checking path for casts")] + [Fact] public void GenerateBindingWhenGetterContainsSimpleReferenceTypeCast() { var source = """ @@ -309,7 +309,7 @@ public void GenerateBindingWhenGetterContainsSimpleReferenceTypeCast() class Foo { - public object Value { get; set; } + public object Value { get; set; } = "Value"; } """; @@ -317,7 +317,7 @@ class Foo var expectedBinding = new CodeWriterBinding( new SourceCodeLocation(@"Path\To\Program.cs", 3, 7), new TypeDescription("global::Foo"), - new TypeDescription("string", IsNullable: true), // May be hard + new TypeDescription("string", IsNullable: true), [ new Cast( new MemberAccess("Value"), diff --git a/src/Controls/tests/BindingSourceGen.UnitTests/IntegrationTests.cs b/src/Controls/tests/BindingSourceGen.UnitTests/IntegrationTests.cs index f8b1c47218e5..107bd8d4eb13 100644 --- a/src/Controls/tests/BindingSourceGen.UnitTests/IntegrationTests.cs +++ b/src/Controls/tests/BindingSourceGen.UnitTests/IntegrationTests.cs @@ -187,7 +187,7 @@ file static class GeneratedBindableObjectExtensions public static void SetBinding1( this BindableObject bindableObject, BindableProperty bidnableProperty, - Func getter, + Func getter, BindingMode mode = BindingMode.Default, IValueConverter? converter = null, object? converterParameter = null, @@ -196,7 +196,7 @@ public static void SetBinding1( object? fallbackValue = null, object? targetNullValue = null) { - var binding = new TypedBinding( + var binding = new TypedBinding( getter: source => (getter(source), true), setter: null, handlers: new Tuple, string>[] diff --git a/src/Controls/tests/BindingSourceGen.UnitTests/SourceGenHelpers.cs b/src/Controls/tests/BindingSourceGen.UnitTests/SourceGenHelpers.cs index fc2672923610..66a66f9fd538 100644 --- a/src/Controls/tests/BindingSourceGen.UnitTests/SourceGenHelpers.cs +++ b/src/Controls/tests/BindingSourceGen.UnitTests/SourceGenHelpers.cs @@ -33,7 +33,7 @@ internal static CodeGeneratorResult Run(string source) var result = driver.RunGeneratorsAndUpdateCompilation(inputCompilation, out Compilation compilation, out _).GetRunResult().Results.Single(); var generatedCodeDiagnostic = compilation.GetDiagnostics(); - var generatedCode = result.GeneratedSources.Single().SourceText.ToString(); + var generatedCode = result.GeneratedSources.Length == 1 ? result.GeneratedSources.Single().SourceText.ToString() : ""; var trackedSteps = result.TrackedSteps;