Skip to content

Commit

Permalink
Fix buggy code for ordering switch statement cases for messages (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinothamar authored Apr 8, 2024
1 parent a52abef commit f036747
Show file tree
Hide file tree
Showing 7 changed files with 2,025 additions and 54 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Runtime.InteropServices.ComTypes;
using Mediator.SourceGenerator.Extensions;
using Microsoft.CodeAnalysis.CSharp;

Expand Down Expand Up @@ -264,22 +266,35 @@ public void Analyze()
}
}

private sealed class InheritanceComparer : IComparer<INamedTypeSymbol>
private static ImmutableEquatableArray<TModel> ToModelsSortedByInheritanceDepth<TSource, TModel>(
HashSet<TSource> source,
Func<TSource, TModel> selector
)
where TSource : SymbolMetadata<TSource>
where TModel : SymbolMetadataModel, IEquatable<TModel>
{
public int Compare(INamedTypeSymbol x, INamedTypeSymbol y)
var analysis = new (TSource Message, int Depth)[source.Count];
int i = 0;
foreach (var message in source)
{
while (x.BaseType is not null)
var baseType = message.Symbol.BaseType;
int depth = 0;
while (baseType is not null && baseType.SpecialType != SpecialType.System_Object)
{
if (x.BaseType.SpecialType == SpecialType.System_Object)
break;

if (SymbolEqualityComparer.Default.Equals(x.BaseType, y))
return -1;
x = x.BaseType;
depth++;
baseType = baseType.BaseType;
}

return x.GetTypeSymbolFullName().CompareTo(y.GetTypeSymbolFullName());
Debug.Assert(i < source.Count);
analysis[i++] = (message, depth);
}

Array.Sort(analysis, (x, y) => y.Depth.CompareTo(x.Depth));
var models = new TModel[source.Count];
for (i = 0; i < source.Count; i++)
models[i] = selector(analysis[i].Message);

return new ImmutableEquatableArray<TModel>(models);
}

public CompilationModel ToModel()
Expand All @@ -289,28 +304,27 @@ public CompilationModel ToModel()

try
{
var comparer = new InheritanceComparer();
if (_notificationPublisherImplementationSymbol is null)
throw new Exception("Unexpected state: NotificationPublisherImplementationSymbol is null");

var model = new CompilationModel(
_requestMessages
.OrderBy(m => m.Symbol, comparer)
.Select(x => new RequestMessageModel(
x.Symbol,
x.ResponseSymbol,
x.MessageType,
x.Handler?.ToModel(),
x.WrapperType
))
.ToImmutableEquatableArray(),
_notificationMessages
.OrderBy(m => m.Symbol, comparer)
.Select(x => x.ToModel())
.ToImmutableEquatableArray(),
ToModelsSortedByInheritanceDepth(
_requestMessages,
m => new RequestMessageModel(
m.Symbol,
m.ResponseSymbol,
m.MessageType,
m.Handler?.ToModel(),
m.WrapperType
)
),
ToModelsSortedByInheritanceDepth(_notificationMessages, m => m.ToModel()),
_requestMessageHandlers.Select(x => x.ToModel()).ToImmutableEquatableArray(),
_notificationMessageHandlers.Select(x => x.ToModel()).ToImmutableEquatableArray(),
RequestMessageHandlerWrappers.ToImmutableEquatableArray(),
new NotificationPublisherTypeModel(
_notificationPublisherImplementationSymbol!.GetTypeSymbolFullName(),
_notificationPublisherImplementationSymbol!.Name
_notificationPublisherImplementationSymbol.GetTypeSymbolFullName(),
_notificationPublisherImplementationSymbol.Name
),
HasErrors,
MediatorNamespace,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public sealed class ImmutableEquatableArray<T> : IEquatable<ImmutableEquatableAr
public T this[int index] => _values[index];
public int Count => _values.Length;

public ImmutableEquatableArray(T[] values) => _values = values;

public ImmutableEquatableArray(IEnumerable<T> values) => _values = values.ToArray();

public bool Equals(ImmutableEquatableArray<T>? other) =>
Expand Down
64 changes: 63 additions & 1 deletion test/Mediator.SourceGenerator.Tests/MessageOrderingTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ await inputCompilation.AssertAndVerify(
Assert.Equal(5, notifications.Count);

var last = notifications[^1];
last.Name.Equals("DomainEvent");
last.Name.Should().Be("DomainEvent");

var index0 = notifications.FindIndex(n => n.Name == "RoundSucceededActually");
var index1 = notifications.FindIndex(n => n.Name == "RoundSucceeded");
Expand All @@ -55,4 +55,66 @@ await inputCompilation.AssertAndVerify(
}
);
}

[Fact]
public async Task Test_Notifications_Ordering_Bigger()
{
var inputCompilation = Fixture.CreateLibrary(
"""
using Mediator;
using System.Threading.Tasks;
using System;
namespace TestCode;
public class Program
{
public static void Main()
{
}
}
public record DomainEvent(DateTimeOffset Timestamp) : INotification;
public record RoundCreated(long Id, DateTimeOffset Timestamp) : DomainEvent(Timestamp);
public record RoundResulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent(Timestamp);
public record RoundSucceeded(long Id, DateTimeOffset Timestamp) : DomainEvent(Timestamp);
public record RoundSucceededActually(long Id, string Because, DateTimeOffset Timestamp) : RoundSucceeded(Id, Timestamp);
public record DomainEvent2(DateTimeOffset Timestamp) : INotification;
public record Round2Created(long Id, DateTimeOffset Timestamp) : DomainEvent2(Timestamp);
public record Round2Resulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent2(Timestamp);
public record Round2Succeeded(long Id, DateTimeOffset Timestamp) : DomainEvent2(Timestamp);
public record Round2SucceededActually(long Id, string Because, DateTimeOffset Timestamp) : RoundSucceeded(Id, Timestamp);
public record DomainEvent10(DateTimeOffset Timestamp) : INotification;
public record Sound2Created(long Id, DateTimeOffset Timestamp) : DomainEvent10(Timestamp);
public record Sound2Resulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent10(Timestamp);
public record Sound2Succeeded(long Id, DateTimeOffset Timestamp) : DomainEvent10(Timestamp);
public record Sound2SucceededActually(long Id, string Because, DateTimeOffset Timestamp) : RoundSucceeded(Id, Timestamp);
public record DomainEvent11(DateTimeOffset Timestamp) : INotification;
public record Sound20Created(long Id, DateTimeOffset Timestamp) : DomainEvent11(Timestamp);
public record Sound20Resulted(long Id, long Win, DateTimeOffset Timestamp) : DomainEvent11(Timestamp);
public record Sound20Succeeded(long Id, DateTimeOffset Timestamp) : DomainEvent11(Timestamp);
public record Sound20SucceededActually(long Id, string Because, DateTimeOffset Timestamp) : Sound20Succeeded(Id, Timestamp);
"""
);

await inputCompilation.AssertAndVerify(
Assertions.CompilesWithoutDiagnostics,
result =>
{
var model = result.Generator.CompilationModel;
Assert.NotNull(model);
var notifications = model.NotificationMessages.ToList();
Assert.Equal(5 * 4, notifications.Count);

Assert.All(notifications.AsEnumerable().Take(4), n => n.Name.Should().EndWith("Actually"));
Assert.All(
notifications.AsEnumerable().Reverse().Take(4),
n => n.Name.Should().StartWith("DomainEvent")
);
}
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,9 @@ private readonly struct DICache
private readonly global::System.IServiceProvider _sp;

public readonly global::Mediator.INotificationHandler<global::TestCode.RoundSucceededActually>[] Handlers_For_TestCode_RoundSucceededActually;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[] Handlers_For_TestCode_RoundSucceeded;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundResulted>[] Handlers_For_TestCode_RoundResulted;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[] Handlers_For_TestCode_RoundCreated;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundResulted>[] Handlers_For_TestCode_RoundResulted;
public readonly global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[] Handlers_For_TestCode_RoundSucceeded;
public readonly global::Mediator.INotificationHandler<global::TestCode.DomainEvent>[] Handlers_For_TestCode_DomainEvent;

public readonly global::Mediator.ForeachAwaitPublisher InternalNotificationPublisherImpl;
Expand All @@ -562,18 +562,18 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceededActually is not global::Mediator.INotificationHandler<global::TestCode.RoundSucceededActually>[]);
Handlers_For_TestCode_RoundSucceededActually = handlers_For_TestCode_RoundSucceededActually.ToArray();
}
var handlers_For_TestCode_RoundSucceeded = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>>();
var handlers_For_TestCode_RoundCreated = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]>(
handlers_For_TestCode_RoundSucceeded
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]>(
handlers_For_TestCode_RoundCreated
);
}
else
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is not global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = handlers_For_TestCode_RoundSucceeded.ToArray();
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is not global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = handlers_For_TestCode_RoundCreated.ToArray();
}
var handlers_For_TestCode_RoundResulted = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundResulted>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
Expand All @@ -588,18 +588,18 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundResulted is not global::Mediator.INotificationHandler<global::TestCode.RoundResulted>[]);
Handlers_For_TestCode_RoundResulted = handlers_For_TestCode_RoundResulted.ToArray();
}
var handlers_For_TestCode_RoundCreated = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>>();
var handlers_For_TestCode_RoundSucceeded = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]>(
handlers_For_TestCode_RoundCreated
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = global::System.Runtime.CompilerServices.Unsafe.As<global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]>(
handlers_For_TestCode_RoundSucceeded
);
}
else
{
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundCreated is not global::Mediator.INotificationHandler<global::TestCode.RoundCreated>[]);
Handlers_For_TestCode_RoundCreated = handlers_For_TestCode_RoundCreated.ToArray();
global::System.Diagnostics.Debug.Assert(handlers_For_TestCode_RoundSucceeded is not global::Mediator.INotificationHandler<global::TestCode.RoundSucceeded>[]);
Handlers_For_TestCode_RoundSucceeded = handlers_For_TestCode_RoundSucceeded.ToArray();
}
var handlers_For_TestCode_DomainEvent = sp.GetServices<global::Mediator.INotificationHandler<global::TestCode.DomainEvent>>();
if (containerMetadata.ServicesUnderlyingTypeIsArray)
Expand Down Expand Up @@ -834,9 +834,9 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
switch (notification)
{
case global::TestCode.RoundSucceededActually n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundCreated n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.DomainEvent n: return Publish(n, cancellationToken);
default:
{
Expand Down Expand Up @@ -876,30 +876,30 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
);
}
/// <summary>
/// Send a notification of type global::TestCode.RoundSucceeded.
/// Send a notification of type global::TestCode.RoundCreated.
/// Throws <see cref="global::System.ArgumentNullException"/> if message is null.
/// Throws <see cref="global::System.AggregateException"/> if handlers throw exception(s).
/// </summary>
/// <param name="notification">Incoming message</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Awaitable task</returns>
public global::System.Threading.Tasks.ValueTask Publish(
global::TestCode.RoundSucceeded notification,
global::TestCode.RoundCreated notification,
global::System.Threading.CancellationToken cancellationToken = default
)
{
ThrowIfNull(notification, nameof(notification));


var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundSucceeded;
var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundCreated;

if (handlers.Length == 0)
{
return default;
}
var publisher = _diCacheLazy.Value.InternalNotificationPublisherImpl;
return publisher.Publish(
new global::Mediator.NotificationHandlers<global::TestCode.RoundSucceeded>(handlers, isArray: true),
new global::Mediator.NotificationHandlers<global::TestCode.RoundCreated>(handlers, isArray: true),
notification,
cancellationToken
);
Expand Down Expand Up @@ -934,30 +934,30 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
);
}
/// <summary>
/// Send a notification of type global::TestCode.RoundCreated.
/// Send a notification of type global::TestCode.RoundSucceeded.
/// Throws <see cref="global::System.ArgumentNullException"/> if message is null.
/// Throws <see cref="global::System.AggregateException"/> if handlers throw exception(s).
/// </summary>
/// <param name="notification">Incoming message</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Awaitable task</returns>
public global::System.Threading.Tasks.ValueTask Publish(
global::TestCode.RoundCreated notification,
global::TestCode.RoundSucceeded notification,
global::System.Threading.CancellationToken cancellationToken = default
)
{
ThrowIfNull(notification, nameof(notification));


var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundCreated;
var handlers = _diCacheLazy.Value.Handlers_For_TestCode_RoundSucceeded;

if (handlers.Length == 0)
{
return default;
}
var publisher = _diCacheLazy.Value.InternalNotificationPublisherImpl;
return publisher.Publish(
new global::Mediator.NotificationHandlers<global::TestCode.RoundCreated>(handlers, isArray: true),
new global::Mediator.NotificationHandlers<global::TestCode.RoundSucceeded>(handlers, isArray: true),
notification,
cancellationToken
);
Expand Down Expand Up @@ -1010,9 +1010,9 @@ public DICache(global::System.IServiceProvider sp, global::Mediator.ContainerMet
switch (notification)
{
case global::TestCode.RoundSucceededActually n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundCreated n: return Publish(n, cancellationToken);
case global::TestCode.RoundResulted n: return Publish(n, cancellationToken);
case global::TestCode.RoundSucceeded n: return Publish(n, cancellationToken);
case global::TestCode.DomainEvent n: return Publish(n, cancellationToken);
default:
{
Expand Down
Loading

0 comments on commit f036747

Please sign in to comment.