Skip to content

Commit

Permalink
fixed lambda return type nullability inference
Browse files Browse the repository at this point in the history
  • Loading branch information
jkurdek committed Apr 8, 2024
1 parent 4923ac7 commit 766f36d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
19 changes: 9 additions & 10 deletions src/Controls/src/BindingSourceGen/BindingSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ static bool IsSetBindingMethod(SyntaxNode node)
static BindingDiagnosticsWrapper GetBindingForGeneration(GeneratorSyntaxContext context, CancellationToken t)
{
var diagnostics = new List<Diagnostic>();
var invocation = (InvocationExpressionSyntax)context.Node;
NullableContext nullableContext = context.SemanticModel.GetNullableContext(context.Node.Span.Start);
var enabledNullable = (nullableContext & NullableContext.Enabled) == NullableContext.Enabled;

var invocation = (InvocationExpressionSyntax)context.Node;
var method = (MemberAccessExpressionSyntax)invocation.Expression;

var sourceCodeLocation = new SourceCodeLocation(
Expand All @@ -82,26 +84,23 @@ static BindingDiagnosticsWrapper GetBindingForGeneration(GeneratorSyntaxContext
return ReportDiagnostics(lambdaDiagnostics);
}

NullableContext nullableContext = context.SemanticModel.GetNullableContext(context.Node.Span.Start);
var enabledNullable = (nullableContext & NullableContext.Enabled) == NullableContext.Enabled;
var lambdaTypeInfo = context.SemanticModel.GetTypeInfo(lambdaBody, t);
if (lambdaTypeInfo.Type == null)
{
return ReportDiagnostics([DiagnosticsFactory.UnableToResolvePath(lambdaBody.GetLocation())]); // TODO: New diagnostic
}

var pathParser = new PathParser(context);
var (pathDiagnostics, parts) = pathParser.ParsePath(lambdaBody);

if (pathDiagnostics.Length > 0)
{
return ReportDiagnostics(pathDiagnostics);
}

// Sometimes analysing just the return type of the lambda is not enough. TODO: Refactor
// var propertyType = BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaSymbol.ReturnType, enabledNullable);
// var lastMember = parts.Last() is Cast cast ? cast.Part : parts.Last();
// propertyType = propertyType with { IsNullable = lastMember is ConditionalAccess || propertyType.IsNullable };

var codeWriterBinding = new CodeWriterBinding(
Location: sourceCodeLocation,
SourceType: BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaSymbol.Parameters[0].Type, enabledNullable),
PropertyType: BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaSymbol.ReturnType, enabledNullable),
PropertyType: BindingGenerationUtilities.CreateTypeNameFromITypeSymbol(lambdaTypeInfo.Type, enabledNullable),
Path: parts.ToArray(),
GenerateSetter: false //TODO: Implement
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public void GenerateBindingWithNullableSourceReferenceAndNullableReferenceElemen
AssertExtensions.BindingsAreEqual(expectedBinding, codeGeneratorResult);
}

[Fact(Skip = "Require checking path for elements that can be null")]
[Fact]
public void GenerateBindingWithNullablePropertyReferenceWhenNullableEnabled()
{
var source = """
Expand Down Expand Up @@ -299,7 +299,7 @@ class Foo
AssertExtensions.BindingsAreEqual(expectedBinding, codeGeneratorResult);
}

[Fact(Skip = "Requires checking path for casts")]
[Fact]
public void GenerateBindingWhenGetterContainsSimpleReferenceTypeCast()
{
var source = """
Expand All @@ -309,15 +309,15 @@ public void GenerateBindingWhenGetterContainsSimpleReferenceTypeCast()
class Foo
{
public object Value { get; set; }
public object Value { get; set; } = "Value";
}
""";

var codeGeneratorResult = SourceGenHelpers.Run(source);
var expectedBinding = new CodeWriterBinding(
new SourceCodeLocation(@"Path\To\Program.cs", 3, 7),
new TypeDescription("global::Foo"),
new TypeDescription("string", IsNullable: true), // May be hard
new TypeDescription("string", IsNullable: true),
[
new Cast(
new MemberAccess("Value"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ file static class GeneratedBindableObjectExtensions
public static void SetBinding1(
this BindableObject bindableObject,
BindableProperty bidnableProperty,
Func<global::MyNamespace.MySourceClass, global::MyNamespace.MyPropertyClass> getter,
Func<global::MyNamespace.MySourceClass, global::MyNamespace.MyPropertyClass?> getter,
BindingMode mode = BindingMode.Default,
IValueConverter? converter = null,
object? converterParameter = null,
Expand All @@ -196,7 +196,7 @@ public static void SetBinding1(
object? fallbackValue = null,
object? targetNullValue = null)
{
var binding = new TypedBinding<global::MyNamespace.MySourceClass, global::MyNamespace.MyPropertyClass>(
var binding = new TypedBinding<global::MyNamespace.MySourceClass, global::MyNamespace.MyPropertyClass?>(
getter: source => (getter(source), true),
setter: null,
handlers: new Tuple<Func<global::MyNamespace.MySourceClass, object?>, string>[]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ internal static CodeGeneratorResult Run(string source)
var result = driver.RunGeneratorsAndUpdateCompilation(inputCompilation, out Compilation compilation, out _).GetRunResult().Results.Single();

var generatedCodeDiagnostic = compilation.GetDiagnostics();
var generatedCode = result.GeneratedSources.Single().SourceText.ToString();
var generatedCode = result.GeneratedSources.Length == 1 ? result.GeneratedSources.Single().SourceText.ToString() : "";

var trackedSteps = result.TrackedSteps;

Expand Down

0 comments on commit 766f36d

Please sign in to comment.