Skip to content

Commit

Permalink
Added Support for Subscription Arguments on the Stream Factory (#5691)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Jan 16, 2023
1 parent 8848df5 commit f8b573e
Show file tree
Hide file tree
Showing 11 changed files with 732 additions and 523 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,22 @@ public SubscribeResolverDelegate CompileSubscribe(

if (member is MethodInfo method)
{
var parameters = method.GetParameters();
var owner = CreateResolverOwner(_context, sourceType, resolverType);
var parameterExpr = CreateParameters(_context, parameters, _empty);
Expression subscribeResolver = Call(owner, method, parameterExpr);
subscribeResolver = EnsureSubscribeResult(subscribeResolver, method.ReturnType);
return Lambda<SubscribeResolverDelegate>(subscribeResolver, _context).Compile();
if (method.IsStatic)
{
var parameterExpr = CreateParameters(_context, method.GetParameters(), _empty);
Expression subscribeResolver = Call(method, parameterExpr);
subscribeResolver = EnsureSubscribeResult(subscribeResolver, method.ReturnType);
return Lambda<SubscribeResolverDelegate>(subscribeResolver, _context).Compile();
}
else
{
var parameters = method.GetParameters();
var owner = CreateResolverOwner(_context, sourceType, resolverType);
var parameterExpr = CreateParameters(_context, parameters, _empty);
Expression subscribeResolver = Call(owner, method, parameterExpr);
subscribeResolver = EnsureSubscribeResult(subscribeResolver, method.ReturnType);
return Lambda<SubscribeResolverDelegate>(subscribeResolver, _context).Compile();
}
}

throw new ArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,23 @@ public override void OnConfigure(
}
else
{

descriptor.Extend().OnBeforeCreate(d =>
{
var subscribeResolver = member.DeclaringType?.GetMethod(With!, Public | Instance);

if (subscribeResolver is null)
descriptor.Extend().OnBeforeCreate(
d =>
{
throw SubscribeAttribute_SubscribeResolverNotFound(member, With);
}

d.SubscribeResolver = context.ResolverCompiler.CompileSubscribe(
subscribeResolver, d.SourceType!, d.ResolverType);
});
var subscribeResolver = member.DeclaringType?.GetMethod(
With!,
Public | NonPublic | Instance | Static);

if (subscribeResolver is null)
{
throw SubscribeAttribute_SubscribeResolverNotFound(member, With);
}

d.SubscribeResolver = context.ResolverCompiler.CompileSubscribe(
subscribeResolver,
d.SourceType!,
d.ResolverType);
});
}
}

Expand All @@ -88,7 +92,9 @@ private static string ResolveTopicString(MethodInfo method)
return method.Name;
}

private static void SubscribeFactory<TMessage>(ObjectFieldDefinition fieldDef, string topicString)
private static void SubscribeFactory<TMessage>(
ObjectFieldDefinition fieldDef,
string topicString)
{
var arg = false;

Expand Down Expand Up @@ -125,10 +131,10 @@ private static SubscribeResolverDelegate CreateSubscribeResolver<TMessage>(
var ct = ctx.RequestAborted;
var receiver = ctx.Service<ITopicEventReceiver>();
return await receiver.SubscribeAsync<TMessage>(
topicString,
null,
null,
ct)
topicString,
null,
null,
ct)
.ConfigureAwait(false);
};
}
Expand All @@ -154,10 +160,10 @@ private static SubscribeResolverDelegate CreateArgumentSubscribeResolver<TMessag
// last we subscribe with the topic string.
var receiver = ctx.Service<ITopicEventReceiver>();
return await receiver.SubscribeAsync<TMessage>(
topicString,
null,
null,
ct)
topicString,
null,
null,
ct)
.ConfigureAwait(false);
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ public ObjectFieldDefinition(
/// </summary>
public Type? ResultType { get; set; }

/// <summary>
/// The member name that represents the event stream factory.
/// </summary>
public string? SubscribeWith { get; set; }

/// <summary>
/// The delegate that represents the resolver.
/// </summary>
Expand Down Expand Up @@ -315,6 +320,7 @@ internal void CopyTo(ObjectFieldDefinition target)
target.IsIntrospectionField = IsIntrospectionField;
target.IsParallelExecutable = IsParallelExecutable;
target.HasStreamResult = HasStreamResult;
target.SubscribeWith = SubscribeWith;
}

internal void MergeInto(ObjectFieldDefinition target)
Expand Down Expand Up @@ -396,6 +402,11 @@ internal void MergeInto(ObjectFieldDefinition target)
{
target.SubscribeResolver = SubscribeResolver;
}

if (SubscribeWith is not null)
{
target.SubscribeWith = SubscribeWith;
}
}

private static void CleanMiddlewareDefinitions<T>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ protected internal void MergeInto(ObjectTypeDefinition target)
newField.SourceType = target.RuntimeType;

SetResolverMember(newField, targetField);

target.Fields.Add(newField);
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public class InterfaceFieldDescriptor
: OutputFieldDescriptorBase<InterfaceFieldDefinition>
, IInterfaceFieldDescriptor
{
private ParameterInfo[] _parameterInfos = Array.Empty<ParameterInfo>();
private bool _argumentsInitialized;

protected internal InterfaceFieldDescriptor(
Expand Down Expand Up @@ -48,7 +49,8 @@ protected internal InterfaceFieldDescriptor(

if (member is MethodInfo m)
{
Parameters = m.GetParameters().ToDictionary(t => t.Name, StringComparer.Ordinal);
_parameterInfos = m.GetParameters();
Parameters = _parameterInfos.ToDictionary(t => t.Name, StringComparer.Ordinal);
}
}

Expand All @@ -75,6 +77,7 @@ private void CompleteArguments(InterfaceFieldDefinition definition)
Context,
definition.Arguments,
definition.Member,
_parameterInfos,
definition.GetParameterExpressionBuilders());
_argumentsInitialized = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using HotChocolate.Types.Descriptors.Definitions;
using HotChocolate.Types.Helpers;
using HotChocolate.Utilities;
using static System.Reflection.BindingFlags;
using static HotChocolate.Execution.ExecutionStrategy;

#nullable enable
Expand Down Expand Up @@ -157,17 +158,58 @@ protected override void OnCreateDefinition(ObjectFieldDefinition definition)

private void CompleteArguments(ObjectFieldDefinition definition)
{
if (!_argumentsInitialized && Parameters.Count > 0)
if (!_argumentsInitialized)
{
Context.ResolverCompiler.ApplyConfiguration(
_parameterInfos,
this);

FieldDescriptorUtilities.DiscoverArguments(
Context,
definition.Arguments,
definition.Member,
definition.GetParameterExpressionBuilders());
if (definition.SubscribeWith is not null)
{
var ownerType = definition.ResolverType ?? definition.SourceType;

if (ownerType is not null)
{
var subscribeMember = ownerType.GetMember(
definition.SubscribeWith,
Public | NonPublic | Instance | Static)[0];

if (subscribeMember is MethodInfo subscribeMethod)
{
var subscribeParameters = subscribeMethod.GetParameters();
var parameterLength = _parameterInfos.Length + subscribeParameters.Length;
var parameters = new ParameterInfo[parameterLength];

_parameterInfos.CopyTo(parameters, 0);
subscribeParameters.CopyTo(parameters, _parameterInfos.Length);
_parameterInfos = parameters;

var parameterLookup = Parameters.ToDictionary(
t => t.Key,
t => t.Value,
StringComparer.Ordinal);
Parameters = parameterLookup;

foreach (var parameter in subscribeParameters)
{
if (!parameterLookup.ContainsKey(parameter.Name!))
{
parameterLookup.Add(parameter.Name!, parameter);
}
}
}
}
}

if (Parameters.Count > 0)
{
Context.ResolverCompiler.ApplyConfiguration(
_parameterInfos,
this);

FieldDescriptorUtilities.DiscoverArguments(
Context,
definition.Arguments,
definition.Member,
_parameterInfos,
definition.GetParameterExpressionBuilders());
}

_argumentsInitialized = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace HotChocolate.Types.Descriptors;

public class ObjectTypeDescriptor
: DescriptorBase<ObjectTypeDefinition>
, IObjectTypeDescriptor
, IObjectTypeDescriptor
{
private readonly List<ObjectFieldDescriptor> _fields = new();

Expand Down Expand Up @@ -130,7 +130,10 @@ protected void InferFieldsFromFieldBindingType(
IDictionary<string, ObjectFieldDefinition> fields,
ISet<MemberInfo> handledMembers)
{
HashSet<string>? subscribeResolver = null;
var skip = false;
HashSet<string>? subscribeRes = null;
Dictionary<MemberInfo, string>? subscribeResLook = null;


if (Definition.Fields.IsImplicitBinding() &&
Definition.FieldBindingType is not null)
Expand All @@ -148,14 +151,20 @@ protected void InferFieldsFromFieldBindingType(

if (handledMembers.Add(member) &&
!fields.ContainsKey(name) &&
IncludeField(ref subscribeResolver, members, member))
IncludeField(ref skip, ref subscribeRes, ref subscribeResLook, members, member))
{
var descriptor = ObjectFieldDescriptor.New(
Context,
member,
Definition.RuntimeType,
type);

if (subscribeResLook is not null &&
subscribeResLook.TryGetValue(member, out var with))
{
descriptor.Definition.SubscribeWith = with;
}

if (isExtension && inspector.IsMemberIgnored(member))
{
descriptor.Ignore();
Expand All @@ -173,35 +182,38 @@ protected void InferFieldsFromFieldBindingType(
}

static bool IncludeField(
ref bool skip,
ref HashSet<string>? subscribeResolver,
ref Dictionary<MemberInfo, string>? subscribeResolverLookup,
ReadOnlySpan<MemberInfo> allMembers,
MemberInfo current)
{
if (subscribeResolver is null)
// if there is now with declared we can include all members.
if (skip)
{
subscribeResolver = new HashSet<string>();
return true;
}

if (subscribeResolver is null)
{
foreach (var member in allMembers)
{
HandlePossibleSubscribeMember(subscribeResolver, member);
if (member.IsDefined(typeof(SubscribeAttribute)) &&
member.GetCustomAttribute<SubscribeAttribute>() is { With: not null } a)
{
subscribeResolver ??= new HashSet<string>();
subscribeResolverLookup ??= new Dictionary<MemberInfo, string>();
subscribeResolver.Add(a.With);
subscribeResolverLookup.Add(member, a.With);
}
}

skip = subscribeResolver is null;
}

return !subscribeResolver.Contains(current.Name);
return !subscribeResolver?.Contains(current.Name) ?? true;
}

static void HandlePossibleSubscribeMember(
HashSet<string> subscribeResolver,
MemberInfo member)
{
if (member.IsDefined(typeof(SubscribeAttribute)))
{
if (member.GetCustomAttribute<SubscribeAttribute>() is { With: not null } attr)
{
subscribeResolver.Add(attr.With);
}
}
}
}

protected virtual void OnCompleteFields(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,15 @@ public static void DiscoverArguments(
IDescriptorContext context,
ICollection<ArgumentDefinition> arguments,
MemberInfo? member,
ParameterInfo[] parameters,
IReadOnlyList<IParameterExpressionBuilder>? parameterExpressionBuilders)
{
if (arguments is null)
{
throw new ArgumentNullException(nameof(arguments));
}

if (member is MethodInfo method)
if (member is MethodInfo)
{
var processedNames = TypeMemHelper.RentNameSet();

Expand All @@ -122,7 +123,7 @@ public static void DiscoverArguments(

foreach (var parameter in
context.ResolverCompiler.GetArgumentParameters(
method.GetParameters(),
parameters,
parameterExpressionBuilders))
{
var argumentDefinition =
Expand Down
Loading

0 comments on commit f8b573e

Please sign in to comment.