diff --git a/Harmony/Internal/PatchTools.cs b/Harmony/Internal/PatchTools.cs index d8c4f396..76de8dfe 100644 --- a/Harmony/Internal/PatchTools.cs +++ b/Harmony/Internal/PatchTools.cs @@ -86,6 +86,11 @@ internal static MethodBase GetOriginalMethod(this HarmonyMethod attr) return null; return AccessTools.EnumeratorMoveNext(AccessTools.DeclaredMethod(attr.GetDeclaringType(), attr.methodName, attr.argumentTypes)); + + case MethodType.Async: + if (attr.methodName is null) + return null; + return AccessTools.AsyncMoveNext(AccessTools.DeclaredMethod(attr.GetDeclaringType(), attr.methodName, attr.argumentTypes)); } } catch (AmbiguousMatchException ex) diff --git a/Harmony/Public/Attributes.cs b/Harmony/Public/Attributes.cs index 2ef7d2aa..80eb0103 100644 --- a/Harmony/Public/Attributes.cs +++ b/Harmony/Public/Attributes.cs @@ -21,7 +21,8 @@ public enum MethodType StaticConstructor, /// This is an enumerator (, or UniTask coroutine) /// This path will target the method that contains the actual enumerator code - Enumerator + Enumerator, + Async } /// Specifies the type of argument diff --git a/Harmony/Public/Harmony.cs b/Harmony/Public/Harmony.cs index bc6069d3..33798df6 100644 --- a/Harmony/Public/Harmony.cs +++ b/Harmony/Public/Harmony.cs @@ -144,7 +144,7 @@ public void PatchAll(Assembly assembly) { AccessTools.GetTypesFromAssembly(assembly).Do(type => CreateClassProcessor(type).Patch()); } - + /// Searches the given type for Harmony annotation and uses them to create patches /// The type to search /// diff --git a/Harmony/Tools/AccessTools.cs b/Harmony/Tools/AccessTools.cs index 602f44c7..e42c7da1 100644 --- a/Harmony/Tools/AccessTools.cs +++ b/Harmony/Tools/AccessTools.cs @@ -542,6 +542,38 @@ public static MethodInfo EnumeratorMoveNext(MethodBase enumerator) return moveNext; } + private static readonly Type _stateMachineAttributeType = typeof(object).Assembly.GetType("System.Runtime.CompilerServices.AsyncStateMachineAttribute"); + private static readonly MethodInfo _stateMachineTypeGetter = _stateMachineAttributeType?.GetProperty("StateMachineType").GetGetMethod(); + + /// Gets the method of an async method's state machine + /// Async method that creates the state machine internally + /// The internal method of the async state machine or null if no valid async method is detected + public static MethodInfo AsyncMoveNext(MethodBase method) + { + if (method is null) + { + FileLog.Debug("AccessTools.AsyncMoveNext: method is null"); + return null; + } + + var asyncAttribute = method.GetCustomAttributes(false).FirstOrDefault(a => a.GetType() == _stateMachineAttributeType); + if (asyncAttribute == null) + { + FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find AsyncStateMachine for {method.FullDescription()}"); + return null; + } + + var asyncStateMachineType = (Type)_stateMachineTypeGetter.Invoke(method, null); + var asyncMethodBody = DeclaredMethod(asyncStateMachineType, "MoveNext"); + if (asyncMethodBody == null) + { + FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find async method body for {method.FullDescription()}"); + return null; + } + + return asyncMethodBody; + } + /// Gets the names of all method that are declared in a type /// The declaring class/type /// A list of method names