diff --git a/Harmony/Internal/PatchTools.cs b/Harmony/Internal/PatchTools.cs
index d8c4f396..76de8dfe 100644
--- a/Harmony/Internal/PatchTools.cs
+++ b/Harmony/Internal/PatchTools.cs
@@ -86,6 +86,11 @@ internal static MethodBase GetOriginalMethod(this HarmonyMethod attr)
return null;
return AccessTools.EnumeratorMoveNext(AccessTools.DeclaredMethod(attr.GetDeclaringType(),
attr.methodName, attr.argumentTypes));
+
+ case MethodType.Async:
+ if (attr.methodName is null)
+ return null;
+ return AccessTools.AsyncMoveNext(AccessTools.DeclaredMethod(attr.GetDeclaringType(), attr.methodName, attr.argumentTypes));
}
}
catch (AmbiguousMatchException ex)
diff --git a/Harmony/Public/Attributes.cs b/Harmony/Public/Attributes.cs
index 2ef7d2aa..80eb0103 100644
--- a/Harmony/Public/Attributes.cs
+++ b/Harmony/Public/Attributes.cs
@@ -21,7 +21,8 @@ public enum MethodType
StaticConstructor,
/// This is an enumerator (, or UniTask coroutine)
/// This path will target the method that contains the actual enumerator code
- Enumerator
+ Enumerator,
+ Async
}
/// Specifies the type of argument
diff --git a/Harmony/Public/Harmony.cs b/Harmony/Public/Harmony.cs
index bc6069d3..33798df6 100644
--- a/Harmony/Public/Harmony.cs
+++ b/Harmony/Public/Harmony.cs
@@ -144,7 +144,7 @@ public void PatchAll(Assembly assembly)
{
AccessTools.GetTypesFromAssembly(assembly).Do(type => CreateClassProcessor(type).Patch());
}
-
+
/// Searches the given type for Harmony annotation and uses them to create patches
/// The type to search
///
diff --git a/Harmony/Tools/AccessTools.cs b/Harmony/Tools/AccessTools.cs
index 602f44c7..e42c7da1 100644
--- a/Harmony/Tools/AccessTools.cs
+++ b/Harmony/Tools/AccessTools.cs
@@ -542,6 +542,38 @@ public static MethodInfo EnumeratorMoveNext(MethodBase enumerator)
return moveNext;
}
+ private static readonly Type _stateMachineAttributeType = typeof(object).Assembly.GetType("System.Runtime.CompilerServices.AsyncStateMachineAttribute");
+ private static readonly MethodInfo _stateMachineTypeGetter = _stateMachineAttributeType?.GetProperty("StateMachineType").GetGetMethod();
+
+ /// Gets the method of an async method's state machine
+ /// Async method that creates the state machine internally
+ /// The internal method of the async state machine or null if no valid async method is detected
+ public static MethodInfo AsyncMoveNext(MethodBase method)
+ {
+ if (method is null)
+ {
+ FileLog.Debug("AccessTools.AsyncMoveNext: method is null");
+ return null;
+ }
+
+ var asyncAttribute = method.GetCustomAttributes(false).FirstOrDefault(a => a.GetType() == _stateMachineAttributeType);
+ if (asyncAttribute == null)
+ {
+ FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find AsyncStateMachine for {method.FullDescription()}");
+ return null;
+ }
+
+ var asyncStateMachineType = (Type)_stateMachineTypeGetter.Invoke(method, null);
+ var asyncMethodBody = DeclaredMethod(asyncStateMachineType, "MoveNext");
+ if (asyncMethodBody == null)
+ {
+ FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find async method body for {method.FullDescription()}");
+ return null;
+ }
+
+ return asyncMethodBody;
+ }
+
/// Gets the names of all method that are declared in a type
/// The declaring class/type
/// A list of method names