Skip to content

Commit

Permalink
Fixed state issues when streaming (#5701)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Jan 20, 2023
1 parent 42a0beb commit a40b0e3
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using System;
using System.Threading.Tasks;
using HotChocolate.Execution.DependencyInjection;
using HotChocolate.Execution.Processing;
using HotChocolate.Fetching;
using HotChocolate.Language;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.ObjectPool;
using static HotChocolate.Execution.GraphQLRequestFlags;
using static HotChocolate.Execution.ThrowHelper;

Expand All @@ -13,7 +13,7 @@ namespace HotChocolate.Execution.Pipeline;
internal sealed class OperationExecutionMiddleware
{
private readonly RequestDelegate _next;
private readonly ObjectPool<OperationContext> _operationContextPool;
private readonly IFactory<OperationContextOwner> _contextFactory;
private readonly QueryExecutor _queryExecutor;
private readonly SubscriptionExecutor _subscriptionExecutor;
private readonly ITransactionScopeHandler _transactionScopeHandler;
Expand All @@ -22,15 +22,15 @@ internal sealed class OperationExecutionMiddleware

public OperationExecutionMiddleware(
RequestDelegate next,
ObjectPool<OperationContext> operationContextPool,
IFactory<OperationContextOwner> contextFactory,
QueryExecutor queryExecutor,
SubscriptionExecutor subscriptionExecutor,
[SchemaService] ITransactionScopeHandler transactionScopeHandler)
{
_next = next ??
throw new ArgumentNullException(nameof(next));
_operationContextPool = operationContextPool ??
throw new ArgumentNullException(nameof(operationContextPool));
_contextFactory = contextFactory ??
throw new ArgumentNullException(nameof(contextFactory));
_queryExecutor = queryExecutor ??
throw new ArgumentNullException(nameof(queryExecutor));
_subscriptionExecutor = subscriptionExecutor ??
Expand Down Expand Up @@ -76,7 +76,7 @@ private async Task ExecuteOperationAsync(
{
if (operation.Definition.Operation == OperationType.Subscription)
{
// since the context is pooled we need to clone the context for
// since the request context is pooled we need to clone the context for
// long running executions.
var cloned = context.Clone();

Expand All @@ -91,7 +91,8 @@ private async Task ExecuteOperationAsync(
}
else
{
var operationContext = _operationContextPool.Get();
var operationContextOwner = _contextFactory.Create();
var operationContext = operationContextOwner.OperationContext;

try
{
Expand All @@ -102,22 +103,21 @@ await ExecuteQueryOrMutationAsync(
if (operationContext.DeferredScheduler.HasResults &&
context.Result is IQueryResult result)
{
var stream = operationContext.DeferredScheduler.CreateResultStream(result);

context.Result = new ResponseStream(
() => stream,
var results = operationContext.DeferredScheduler.CreateResultStream(result);
var responseStream = new ResponseStream(
() => results,
ExecutionResultKind.DeferredResult);
context.Result.RegisterForCleanup(result);
responseStream.RegisterForCleanup(result);
responseStream.RegisterForCleanup(operationContextOwner);
context.Result = responseStream;
operationContextOwner = null;
}

await _next(context).ConfigureAwait(false);
}
finally
{
if (operationContext is not null)
{
_operationContextPool.Return(operationContext);
}
operationContextOwner?.Dispose();
}
}
}
Expand Down
14 changes: 3 additions & 11 deletions src/HotChocolate/Core/src/Execution/Processing/DeferredStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,26 @@ protected override async Task ExecuteAsync(
uint patchId)
{
var operationContext = operationContextOwner.OperationContext;
var aborted = operationContext.RequestAborted;
var error = false;

try
{
_task ??= new StreamExecutionTask(this);
_task.Reset(operationContext, resultId);

operationContext.Scheduler.Register(_task);
await operationContext.Scheduler.ExecuteAsync().ConfigureAwait(false);

// if there is no child task, then there is no more data, so we can complete.
if (_task.ChildTask is null)
{
operationContext.DeferredScheduler.Complete(new(resultId, parentResultId));
operationContextOwner.Dispose();
return;
}

var item = _task.ChildTask.ParentResult[0].Value!;

var result = operationContext
.SetLabel(Label)
.SetPath(operationContext.PathFactory.Append(Path, Index))
.SetPath(operationContext.PathFactory.Append(Path, Index).Clone())
.SetItems(new[] { item })
.SetPatchId(patchId)
.BuildResult();
Expand All @@ -110,19 +106,15 @@ protected override async Task ExecuteAsync(
operationContext.DeferredScheduler.Register(this, patchId);
operationContext.DeferredScheduler.Complete(new(resultId, parentResultId, result));
}
catch(Exception ex)
catch (Exception ex)
{
var builder = operationContext.ErrorHandler.CreateUnexpectedError(ex);
var result = QueryResultBuilder.CreateError(builder.Build());
operationContext.DeferredScheduler.Complete(new(resultId, parentResultId, result));
error = true;
}
finally
{
if (error || aborted.IsCancellationRequested)
{
operationContextOwner.Dispose();
}
operationContextOwner.Dispose();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public DeferredWorkStateOwner(DeferredWorkState state, ObjectPool<DeferredWorkSt

public void Dispose()
{
if (_disposed == 0 && Interlocked.CompareExchange(ref _disposed, 0, 1) == 0)
if (_disposed == 0 && Interlocked.CompareExchange(ref _disposed, 1, 0) == 0)
{
_statePool.Return(State);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using HotChocolate.Execution.DependencyInjection;
Expand Down
1 change: 0 additions & 1 deletion src/HotChocolate/Core/test/Execution.Tests/StreamTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ persons @stream(initialCount: 1) {
Assert.IsType<ResponseStream>(result).MatchSnapshot();
}


[Fact]
public async Task Stream_Label_Set_To_abc()
{
Expand Down

0 comments on commit a40b0e3

Please sign in to comment.