Skip to content

Commit

Permalink
Enable nullable context for extensions generator (#878)
Browse files Browse the repository at this point in the history
  • Loading branch information
atifaziz authored Nov 12, 2022
1 parent 382bcf8 commit ee8abeb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<OutputType>Exe</OutputType>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.3.1" />
Expand Down
31 changes: 13 additions & 18 deletions bld/ExtensionsGenerator/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ static void Run(IEnumerable<string> args)
{
var dir = Directory.GetCurrentDirectory();

string includePattern = null;
string excludePattern = null;
string? includePattern = null;
string? excludePattern = null;
var debug = false;
var usings = new List<string>();
var noClassLead = false;
Expand Down Expand Up @@ -88,7 +88,7 @@ static Exception MissingArgValue() =>
}

static Func<string, bool>
PredicateFromPattern(string pattern, bool @default) =>
PredicateFromPattern(string? pattern, bool @default) =>
string.IsNullOrEmpty(pattern)
? delegate { return @default; }
: new Func<string, bool>(new Regex(pattern).IsMatch);
Expand Down Expand Up @@ -132,7 +132,7 @@ from cd in
.SyntaxTree
.GetCompilationUnitRoot()
.DescendantNodes().OfType<ClassDeclarationSyntax>()
where (string) cd.Identifier.Value == "MoreEnumerable"
where cd.Identifier.Value is "MoreEnumerable"
//
// Get all method declarations where method:
//
Expand All @@ -142,10 +142,9 @@ from cd in
// - isn't marked as being obsolete
//
from md in cd.DescendantNodes().OfType<MethodDeclarationSyntax>()
let mn = (string) md.Identifier.Value
where md.ParameterList.Parameters.Count > 0
&& md.ParameterList.Parameters.First().Modifiers.Any(m => (string)m.Value == "this")
&& md.Modifiers.Any(m => (string)m.Value == "public")
&& md.ParameterList.Parameters.First().Modifiers.Any(m => m.Value is "this")
&& md.Modifiers.Any(m => m.Value is "public")
&& md.AttributeLists.SelectMany(al => al.Attributes).All(a => a.Name.ToString() != "Obsolete")
//
// Build a dictionary of type abbreviations (e.g. TSource -> a,
Expand All @@ -172,7 +171,7 @@ where md.ParameterList.Parameters.Count > 0
ParameterCount = md.ParameterList.Parameters.Count,
SortableParameterTypes =
from p in md.ParameterList.Parameters
select CreateTypeKey(p.Type,
select CreateTypeKey(p.Type ?? throw new NullReferenceException(),
n => typeParameterAbbreviationByName is { } someTypeParameterAbbreviationByName
&& someTypeParameterAbbreviationByName.TryGetValue(n, out var a) ? a : null),
}
Expand Down Expand Up @@ -251,7 +250,8 @@ from ns in baseImports.Concat(usings)
var classes =
from md in q
select md.Method.Syntax into md
group md by (string) md.Identifier.Value into g
group md by md.Identifier.Value is string id ? id : throw new NullReferenceException()
into g
select new
{
Name = g.Key,
Expand Down Expand Up @@ -334,8 +334,7 @@ namespace MoreLinq.Extensions
.Replace("\n", Environment.NewLine));
}

static TypeKey CreateTypeKey(TypeSyntax root,
Func<string, TypeKey> abbreviator = null)
static TypeKey CreateTypeKey(TypeSyntax root, Func<string, TypeKey?> abbreviator)
{
return Walk(root ?? throw new ArgumentNullException(nameof(root)));

Expand All @@ -359,12 +358,8 @@ select Walk(te.Type))),
};
}

static T Read<T>(IEnumerator<T> e, Func<Exception> errorFactory = null)
{
if (!e.MoveNext())
throw errorFactory?.Invoke() ?? new InvalidOperationException();
return e.Current;
}
static T Read<T>(IEnumerator<T> e, Func<Exception> errorFactory) =>
e.MoveNext() ? e.Current : throw errorFactory();

//
// Logical type nodes designed to be structurally sortable based on:
Expand All @@ -382,7 +377,7 @@ abstract class TypeKey : IComparable<TypeKey>
public string Name { get; }
public abstract ImmutableList<TypeKey> Parameters { get; }

public virtual int CompareTo(TypeKey other)
public virtual int CompareTo(TypeKey? other)
=> ReferenceEquals(this, other) ? 0
: other == null ? 1
: Parameters.Count.CompareTo(other.Parameters.Count) is var lc and not 0 ? lc
Expand Down

0 comments on commit ee8abeb

Please sign in to comment.