Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve: implementation of ToAwaitable() #88

Merged
merged 2 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using UnityEngine;

namespace LitMotion
{
#if UNITY_2023_1_OR_NEWER

internal sealed class AwaitableMotionConfiguredSource
{
[StructLayout(LayoutKind.Auto)]
public struct Pool
{
static readonly int MaxPoolSize = int.MaxValue;

int gate;
int size;
AwaitableMotionConfiguredSource root;

public readonly int Size => size;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool TryPop(out AwaitableMotionConfiguredSource result)
{
if (Interlocked.CompareExchange(ref gate, 1, 0) == 0)
{
var v = root;
if (v != null)
{
ref var nextNode = ref v.NextNode;
root = nextNode;
nextNode = null;
size--;
result = v;
Volatile.Write(ref gate, 0);
return true;
}

Volatile.Write(ref gate, 0);
}
result = default;
return false;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool TryPush(AwaitableMotionConfiguredSource item)
{
if (Interlocked.CompareExchange(ref gate, 1, 0) == 0)
{
if (size < MaxPoolSize)
{
item.NextNode = root;
root = item;
size++;
Volatile.Write(ref gate, 0);
return true;
}
else
{
Volatile.Write(ref gate, 0);
}
}
return false;
}
}

static Pool pool;
AwaitableMotionConfiguredSource nextNode;
public ref AwaitableMotionConfiguredSource NextNode => ref nextNode;

public static AwaitableMotionConfiguredSource CompletedSource
{
get
{
if (completedSource == null)
{
completedSource = new();
completedSource.core.SetResult();
}
return completedSource;
}
}
static AwaitableMotionConfiguredSource completedSource;

public static AwaitableMotionConfiguredSource CanceledSource
{
get
{
if (canceledSource == null)
{
canceledSource = new();
canceledSource.core.SetCanceled();
}
return canceledSource;
}
}
static AwaitableMotionConfiguredSource canceledSource;

readonly Action onCancelCallbackDelegate;
readonly Action onCompleteCallbackDelegate;

MotionHandle motionHandle;
CancellationToken cancellationToken;
CancellationTokenRegistration cancellationRegistration;

Action originalCompleteAction;
Action originalCancelAction;
readonly AwaitableCompletionSource core = new();

public Awaitable Awaitable => core.Awaitable;

AwaitableMotionConfiguredSource()
{
onCancelCallbackDelegate = new(OnCancelCallbackDelegate);
onCompleteCallbackDelegate = new(OnCompleteCallbackDelegate);
}

public static AwaitableMotionConfiguredSource Create(MotionHandle motionHandle, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
motionHandle.Cancel();
return CanceledSource;
}

if (!pool.TryPop(out var result))
{
result = new AwaitableMotionConfiguredSource();
}

result.motionHandle = motionHandle;
result.cancellationToken = cancellationToken;

var callbacks = MotionStorageManager.GetMotionCallbacks(motionHandle);
result.originalCancelAction = callbacks.OnCancelAction;
result.originalCompleteAction = callbacks.OnCompleteAction;
callbacks.OnCancelAction = result.onCancelCallbackDelegate;
callbacks.OnCompleteAction = result.onCompleteCallbackDelegate;
MotionStorageManager.SetMotionCallbacks(motionHandle, callbacks);

if (result.originalCancelAction == result.onCancelCallbackDelegate)
{
result.originalCancelAction = null;
}
if (result.originalCompleteAction == result.onCompleteCallbackDelegate)
{
result.originalCompleteAction = null;
}

if (cancellationToken.CanBeCanceled)
{
result.cancellationRegistration = cancellationToken.Register(static x =>
{
var source = (AwaitableMotionConfiguredSource)x;
var motionHandle = source.motionHandle;
if (motionHandle.IsActive())
{
motionHandle.Cancel();
}
else
{
source.core.SetCanceled();
source.TryReturn();
}
}, result);
}

return result;
}

void OnCompleteCallbackDelegate()
{
if (cancellationToken.IsCancellationRequested)
{
core.SetCanceled();
}
else
{
originalCompleteAction?.Invoke();
core.SetResult();
}

TryReturn();
}

void OnCancelCallbackDelegate()
{
originalCancelAction?.Invoke();
core.SetCanceled();
TryReturn();
}

bool TryReturn()
{
core.Reset();
cancellationRegistration.Dispose();

RestoreOriginalCallback();

motionHandle = default;
cancellationToken = default;
originalCompleteAction = default;
originalCancelAction = default;
return pool.TryPush(this);
}

void RestoreOriginalCallback()
{
if (!motionHandle.IsActive()) return;
var callbacks = MotionStorageManager.GetMotionCallbacks(motionHandle);
callbacks.OnCancelAction = originalCancelAction;
callbacks.OnCompleteAction = originalCompleteAction;
MotionStorageManager.SetMotionCallbacks(motionHandle, callbacks);
}
}

#endif
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions src/LitMotion/Assets/LitMotion/Runtime/MotionHandleExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Collections;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using UnityEngine;

Expand Down Expand Up @@ -120,12 +122,10 @@ public static IEnumerator ToYieldInteraction(this MotionHandle handle)
}

#if UNITY_2023_1_OR_NEWER
public static async Awaitable ToAwaitable(this MotionHandle handle, CancellationToken cancellationToken = default)
public static Awaitable ToAwaitable(this MotionHandle handle, CancellationToken cancellationToken = default)
{
while (handle.IsActive())
{
await Awaitable.NextFrameAsync(cancellationToken);
}
if (!handle.IsActive()) return AwaitableMotionConfiguredSource.CompletedSource.Awaitable;
return AwaitableMotionConfiguredSource.Create(handle, cancellationToken).Awaitable;
}
#endif

Expand Down
108 changes: 108 additions & 0 deletions src/LitMotion/Assets/LitMotion/Tests/Runtime/AwaitableTest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#if UNITY_2023_1_OR_NEWER
using System;
using System.Threading;
using System.Threading.Tasks;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;

namespace LitMotion.Tests.Runtime
{
Expand All @@ -14,6 +18,110 @@ public async Task Test_Awaitable()
await handle.ToAwaitable();
Assert.That(value, Is.EqualTo(10f));
}

[Test]
public async Task Test_Awaitable_Completed()
{
await default(MotionHandle).ToAwaitable();
}

[Test]
public async Task Test_Awaitable_CancelAwait()
{
var canceled = false;

var source = new CancellationTokenSource();
source.CancelAfter(500);

var handle = LMotion.Create(0f, 10f, 1f)
.WithOnCancel(() => canceled = true)
.RunWithoutBinding();
try
{
await handle.ToAwaitable(source.Token);
}
catch (OperationCanceledException)
{
Assert.IsTrue(canceled);
return;
}
Assert.Fail();
}

[Test]
public async Task Test_Awaitable_WithCanceledToken()
{
var canceled = false;

var source = new CancellationTokenSource();
source.Cancel();

var handle = LMotion.Create(0f, 10f, 1f)
.WithOnCancel(() => canceled = true)
.RunWithoutBinding();
try
{
await handle.ToAwaitable(source.Token);
}
catch (OperationCanceledException)
{
Assert.IsTrue(canceled);
return;
}
Assert.Fail();
}

[Test]
public async Task Test_Awaitable_CancelWhileAwait()
{
var canceled = false;

var handle = LMotion.Create(0f, 10f, 1f)
.WithOnCancel(() => canceled = true)
.RunWithoutBinding();

_ = DelayedCall(0.2f, () => handle.Cancel());

try
{
await handle.ToAwaitable();
}
catch (OperationCanceledException)
{
Assert.IsTrue(canceled);
return;
}
Assert.Fail();
}

[Test]
public async Task Test_CancelWhileAwait_WithCancelOnError()
{
LogAssert.ignoreFailingMessages = true;

var handle = LMotion.Create(0f, 10f, 1f)
.WithCancelOnError()
.Bind(x =>
{
if (x > 5f) throw new Exception("Test");
});

try
{
await handle.ToAwaitable();
}
catch (OperationCanceledException)
{
return;
}
Assert.Fail();
}

async Awaitable DelayedCall(float delay, Action action)
{
await Awaitable.WaitForSecondsAsync(delay);
action.Invoke();
}
}
}
#endif