Skip to content

Commit

Permalink
[MTGOSDK/Core] Add SyncThread parallelism
Browse files Browse the repository at this point in the history
Allows for efficiently queuing and dispatching multiple concurrent requests to ensure I/O intensive callbacks do not stall the event loop.
  • Loading branch information
Qonfused committed Jan 22, 2025
1 parent 199c53a commit 9d1a81a
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 196 deletions.
73 changes: 13 additions & 60 deletions MTGOSDK/lib/ScubaDiver/src/Diver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
using System.Net;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using Newtonsoft.Json;

using MTGOSDK.Core.Compiler.Snapshot;
using MTGOSDK.Core.Logging;
using MTGOSDK.Core.Remoting.Interop;
using MTGOSDK.Core.Remoting.Interop.Interactions;
using MTGOSDK.Core.Remoting.Interop.Interactions.Callbacks;

Expand All @@ -36,9 +34,8 @@ public partial class Diver : IDisposable
private readonly Dictionary<string, Func<HttpListenerRequest, string>> _responseBodyCreators;

// Callbacks Endpoint of the Controller process
private bool _monitorEndpoints = true;
private readonly ConcurrentDictionary<int, RegisteredEventHandlerInfo> _remoteEventHandler;
private readonly ConcurrentDictionary<int, RegisteredMethodHookInfo> _remoteHooks;
private readonly ConcurrentDictionary<int, RegisteredMethodHookInfo> _remoteHooks;

private readonly ManualResetEvent _stayAlive = new(true);

Expand Down Expand Up @@ -84,66 +81,22 @@ public void Start(ushort listenPort)
HttpListener listener = new();
string listeningUrl = $"http://127.0.0.1:{listenPort}/";
listener.Prefixes.Add(listeningUrl);

// Set timeout
var manager = listener.TimeoutManager;
manager.IdleConnection = TimeSpan.FromSeconds(5);
listener.Start();
Log.Debug($"[Diver] Listening on {listeningUrl}...");

Task endpointsMonitor = Task.Run(CallbacksEndpointsMonitor);
Dispatcher(listener);
Log.Debug("[Diver] Stopping Callback Endpoints Monitor");
_monitorEndpoints = false;
try { endpointsMonitor.Wait(); } catch { }

Log.Debug("[Diver] Closing listener");
listener.Stop();
listener.Close();
Log.Debug("[Diver] Closing ClrMD runtime and snapshot");

Log.Debug("[Diver] Unpinning objects");
Log.Debug("[Diver] Unpinning finished");
Log.Debug("[Diver] Closing ClrMD runtime and snapshot");
this.Dispose();

Log.Debug("[Diver] Dispatcher returned, Start is complete.");
}

private void CallbacksEndpointsMonitor()
{
while (_monitorEndpoints)
{
Thread.Sleep(TimeSpan.FromSeconds(1));
IPEndPoint endpoint;
foreach (var registeredMethodHookInfo in _remoteHooks)
{
endpoint = registeredMethodHookInfo.Value.Endpoint;
ReverseCommunicator reverseCommunicator = new(endpoint);
var token = registeredMethodHookInfo.Key;
Log.Debug($"[Diver] Checking if callback client at {endpoint} is alive. Token = {token}. Type = Method Hook");
bool alive = reverseCommunicator.CheckIfAlive();
Log.Debug($"[Diver] Callback client at {endpoint} (Token = {token}) is alive = {alive}");
if (!alive)
{
Log.Debug($"[Diver] Dead Callback client at {endpoint} (Token = {token}) DROPPED!");
if (_remoteHooks.TryRemove(token, out RegisteredMethodHookInfo rmhi))
{
HarmonyWrapper.Instance.RemovePrefix(rmhi.OriginalHookedMethod);
}
}
}
foreach (var registeredEventHandlerInfo in _remoteEventHandler)
{
endpoint = registeredEventHandlerInfo.Value.Endpoint;
ReverseCommunicator reverseCommunicator = new(endpoint);
Log.Debug($"[Diver] Checking if callback client at {endpoint} is alive. Token = {registeredEventHandlerInfo.Key}. Type = Event");
bool alive = reverseCommunicator.CheckIfAlive();
Log.Debug($"[Diver] Callback client at {endpoint} (Token = {registeredEventHandlerInfo.Key}) is alive = {alive}");
if (!alive)
{
Log.Debug($"[Diver] Dead Callback client at {endpoint} (Token = {registeredEventHandlerInfo.Key}) DROPPED!");
_remoteEventHandler.TryRemove(registeredEventHandlerInfo.Key, out _);
}
}
}
Log.Debug("[Diver] Exiting");
}

public string QuickError(string error, string stackTrace = null)
Expand Down Expand Up @@ -198,9 +151,9 @@ private void Dispatcher(HttpListener listener)
{
void ListenerCallback(IAsyncResult result)
{
try
{
HarmonyWrapper.Instance.UnregisterFrameworkThread(Thread.CurrentThread.ManagedThreadId);
// try
// {
// HarmonyWrapper.Instance.UnregisterFrameworkThread(Thread.CurrentThread.ManagedThreadId);

HttpListener listener = (HttpListener)result.AsyncState;
HttpListenerContext context;
Expand Down Expand Up @@ -231,11 +184,11 @@ void ListenerCallback(IAsyncResult result)
{
Log.Debug("[Diver] Task faulted! Exception: " + e.ToString());
}
}
finally
{
HarmonyWrapper.Instance.UnregisterFrameworkThread(Thread.CurrentThread.ManagedThreadId);
}
// }
// finally
// {
// HarmonyWrapper.Instance.UnregisterFrameworkThread(Thread.CurrentThread.ManagedThreadId);
// }
}
IAsyncResult asyncOperation = listener.BeginGetContext(ListenerCallback, listener);

Expand Down
33 changes: 20 additions & 13 deletions MTGOSDK/lib/ScubaDiver/src/DllEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,24 @@ namespace ScubaDiver;

public class DllEntry
{
private static void UseAssemblyLoadHook()
private static Assembly LoadAssembly(string name)
{
// Add a hook to load assemblies next to the current assembly's filepath.
AppDomain.CurrentDomain.AssemblyResolve += (sender, args) =>
{
string assemblyPath = Path.Combine(
Path.GetDirectoryName(typeof(DllEntry).Assembly.Location),
new AssemblyName(args.Name).Name + ".dll"
);
string assemblyPath = Path.Combine(
Path.GetDirectoryName(typeof(DllEntry).Assembly.Location),
name + ".dll"
);

if (File.Exists(assemblyPath))
return Assembly.LoadFrom(assemblyPath);
if (File.Exists(assemblyPath))
return Assembly.LoadFrom(assemblyPath);

return null;
};
return null;
}

private static void UseAssemblyLoadHook()
{
// Add a hook to load assemblies next to the current assembly's filepath.
AppDomain.CurrentDomain.AssemblyResolve += (s, e) =>
LoadAssembly(new AssemblyName(e.Name).Name);
}

private static void DiverHost(object pwzArgument)
Expand Down Expand Up @@ -83,7 +86,11 @@ public static int EntryPoint(string pwzArgument)
// The bootstrapper is expecting to call a C# function with this signature,
// so we use it to start a new thread to host the diver in it's own thread.
ParameterizedThreadStart func = DiverHost;
Thread diverHostThread = new(func);
Thread diverHostThread = new(func)
{
IsBackground = true,
Name = "DiverHostThread",
};
diverHostThread.Start(pwzArgument);

// Block the thread until the diver has exited.
Expand Down
35 changes: 10 additions & 25 deletions MTGOSDK/lib/ScubaDiver/src/Hooking/HarmonyWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
using System.Collections.Concurrent;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;

using HarmonyLib;

using MTGOSDK.Core;
using MTGOSDK.Core.Remoting.Hooking;
using System.Threading.Tasks;


namespace ScubaDiver.Hooking;
Expand Down Expand Up @@ -54,7 +55,7 @@ public override string ToString() =>

private HarmonyWrapper()
{
_harmony = new Harmony("com.videre.mtgoinjector");
_harmony = new Harmony("com.videre.mtgosdk");
_psHooks = new Dictionary<string, MethodInfo>();
var methods = typeof(HarmonyWrapper).GetMethods((BindingFlags)0xffffff);
foreach (MethodInfo method in methods)
Expand Down Expand Up @@ -209,30 +210,14 @@ public void RemovePrefix(MethodBase target)

private static void SinglePrefixHook(MethodBase __originalMethod, object __instance, params object[] args)
{
new Task(() =>
string uniqueId = __originalMethod.DeclaringType.FullName + ":"
+ __originalMethod.Name;
if (_actualHooks.TryGetValue(uniqueId, out HookCallback funcHook))
{
// // Avoid patching a ScubaDiver framework method to avoid infinite recursion.
// SmartLocksDict<MethodBase>.AcquireResults res = _locksDict.Acquire(__originalMethod);
// if(res == SmartLocksDict<MethodBase>.AcquireResults.AlreadyAcquireByCurrentThread ||
// res == SmartLocksDict<MethodBase>.AcquireResults.ThreadNotAllowedToLock)
// {
// return; // Don't skip original
// }

try
{
string uniqueId = __originalMethod.DeclaringType.FullName + ":" + __originalMethod.Name;
if (_actualHooks.TryGetValue(uniqueId, out HookCallback funcHook))
{
// Logger.Debug($"[HarmonyWrapper][SinglePrefixHook] Invoking hook for method {uniqueId}");
funcHook(__instance, args);
}
}
finally
{
// _locksDict.Release(__originalMethod);
}
}).Start();
Action callback = () => funcHook(__instance, args);
// Task.Run(callback);
SyncThread.Enqueue(callback);
}
}

#pragma warning disable IDE0051 // Remove unused private members
Expand Down
99 changes: 99 additions & 0 deletions MTGOSDK/src/Core/ConcurrentTaskScheduler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/** @file
Copyright (c) 2025, Cory Bennett. All rights reserved.
SPDX-License-Identifier: Apache-2.0
**/

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;


namespace MTGOSDK.Core;

/// <summary>
/// A task scheduler that ensures a limited degree of concurrency.
/// </summary>
public class ConcurrentTaskScheduler(
int maxDegreeOfParallelism,
CancellationToken cancellationToken) : TaskScheduler
{
private readonly LinkedList<Task> _tasks = new();
private readonly SemaphoreSlim _semaphore = new(maxDegreeOfParallelism);

protected override IEnumerable<Task> GetScheduledTasks()
{
lock (_tasks)
{
return _tasks.ToArray();
}
}

protected override void QueueTask(Task task)
{
lock (_tasks)
{
_tasks.AddLast(task);
NotifyThreadPoolOfPendingWork();
}
}

private void NotifyThreadPoolOfPendingWork()
{
// Skip propagating the CAS restrictions of the current thread to quickly
// dispatch the work item to another thread in the thread pool.
ThreadPool.UnsafeQueueUserWorkItem(async _ =>
{
try
{
await _semaphore.WaitAsync(cancellationToken);

while (true)
{
Task item;
lock (_tasks)
{
if (_tasks.Count == 0)
break;

item = _tasks.First.Value;
_tasks.RemoveFirst();
}
base.TryExecuteTask(item);
}
}
finally
{
_semaphore.Release();
}
}, null);
}

protected override bool TryExecuteTaskInline(
Task task,
bool taskWasPreviouslyQueued)
{
if (taskWasPreviouslyQueued)
{
if (TryDequeue(task))
{
return base.TryExecuteTask(task);
}
else
{
return false;
}
}
else
{
return base.TryExecuteTask(task);
}
}

protected override bool TryDequeue(Task task)
{
lock (_tasks)
{
return _tasks.Remove(task);
}
}
}
5 changes: 3 additions & 2 deletions MTGOSDK/src/Core/Reflection/EventWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
SPDX-License-Identifier: Apache-2.0
**/


namespace MTGOSDK.Core.Reflection;

/// <summary>
Expand All @@ -22,6 +21,8 @@ public class EventWrapper<T>(EventHandler handler) where T : EventArgs
/// </remarks>
public void Handle(object sender, T args)
{
Task.Run(() => handler.Invoke(sender, args));
Action callback = () => handler.Invoke(sender, args);
// Task.Run(callback);
SyncThread.Enqueue(callback);
}
}
2 changes: 1 addition & 1 deletion MTGOSDK/src/Core/Reflection/Proxy/EventHookProxy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public EventHookProxy(string typeName, string methodName, EventHook hook)

try
{
_eventHook?.Invoke((I)res?.Item1, (T)res?.Item2);
_eventHook?.Invoke(res?.Item1, res?.Item2);
}
catch (Exception e)
{
Expand Down
Loading

0 comments on commit 9d1a81a

Please sign in to comment.