diff --git a/CMakeLists.txt b/CMakeLists.txt index 8a8132a..6a5300b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,6 +93,8 @@ set(safetyhook_SOURCES "src/inline_hook.cpp" "src/mid_hook.cpp" "src/thread_freezer.cpp" + "src/utility.cpp" + "src/vmt_hook.cpp" cmake.toml ) @@ -489,6 +491,7 @@ if(SAFETYHOOK_BUILD_TESTS) # build-tests "unittest/inline_hook.cpp" "unittest/inline_hook.x86_64.cpp" "unittest/mid_hook.cpp" + "unittest/vmt_hook.cpp" cmake.toml ) @@ -557,6 +560,7 @@ if(SAFETYHOOK_BUILD_TESTS) # build-tests "unittest/inline_hook.cpp" "unittest/inline_hook.x86_64.cpp" "unittest/mid_hook.cpp" + "unittest/vmt_hook.cpp" "amalgamated-dist/safetyhook.cpp" cmake.toml ) diff --git a/include/safetyhook.hpp b/include/safetyhook.hpp index 4baf2a6..1206908 100644 --- a/include/safetyhook.hpp +++ b/include/safetyhook.hpp @@ -4,9 +4,12 @@ #include <safetyhook/inline_hook.hpp> #include <safetyhook/mid_hook.hpp> #include <safetyhook/thread_freezer.hpp> +#include <safetyhook/vmt_hook.hpp> using SafetyHookContext = safetyhook::Context; using SafetyHookInline = safetyhook::InlineHook; using SafetyHookMid = safetyhook::MidHook; using SafetyInlineHook [[deprecated("Use SafetyHookInline instead.")]] = safetyhook::InlineHook; using SafetyMidHook [[deprecated("Use SafetyHookMid instead.")]] = safetyhook::MidHook; +using SafetyHookVmt = safetyhook::VmtHook; +using SafetyHookVm = safetyhook::VmHook; diff --git a/include/safetyhook/allocator.hpp b/include/safetyhook/allocator.hpp index fb88366..758d856 100644 --- a/include/safetyhook/allocator.hpp +++ b/include/safetyhook/allocator.hpp @@ -124,4 +124,4 @@ class Allocator final : public std::enable_shared_from_this<Allocator> { [[nodiscard]] static bool in_range( uint8_t* address, const std::vector<uint8_t*>& desired_addresses, size_t max_distance); }; -} // namespace safetyhook \ No newline at end of file +} // namespace safetyhook diff --git a/include/safetyhook/easy.hpp b/include/safetyhook/easy.hpp index 93d1c63..8e12f9f 100644 --- a/include/safetyhook/easy.hpp +++ b/include/safetyhook/easy.hpp @@ -5,6 +5,8 @@ #include <safetyhook/inline_hook.hpp> #include <safetyhook/mid_hook.hpp> +#include <safetyhook/utility.hpp> +#include <safetyhook/vmt_hook.hpp> namespace safetyhook { /// @brief Easy to use API for creating an InlineHook. @@ -14,13 +16,10 @@ namespace safetyhook { [[nodiscard]] InlineHook create_inline(void* target, void* destination); /// @brief Easy to use API for creating an InlineHook. -/// @tparam T The type of the function to hook. /// @param target The address of the function to hook. /// @param destination The address of the destination function. /// @return The InlineHook object. -template <typename T> - requires std::is_function_v<T> -[[nodiscard]] InlineHook create_inline(T* target, T* destination) { +[[nodiscard]] InlineHook create_inline(FnPtr auto target, FnPtr auto destination) { return create_inline(reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination)); } @@ -31,13 +30,29 @@ template <typename T> [[nodiscard]] MidHook create_mid(void* target, MidHookFn destination); /// @brief Easy to use API for creating a MidHook. -/// @tparam T The type of the function to hook. /// @param target the address of the function to hook. /// @param destination The destination function. /// @return The MidHook object. -template <typename T> - requires std::is_function_v<T> -[[nodiscard]] MidHook create_mid(T* target, MidHookFn destination) { +[[nodiscard]] MidHook create_mid(FnPtr auto target, MidHookFn destination) { return create_mid(reinterpret_cast<void*>(target), destination); } + +/// @brief Easy to use API for creating a VmtHook. +/// @param object The object to hook. +/// @return The VmtHook object. +[[nodiscard]] VmtHook create_vmt(void* object); + +/// @brief Easy to use API for creating a VmHook. +/// @param vmt The VmtHook to use to create the VmHook. +/// @param index The index of the method to hook. +/// @param destination The destination function. +/// @return The VmHook object. +[[nodiscard]] VmHook create_vm(VmtHook& vmt, size_t index, FnPtr auto destination) { + if (auto hook = vmt.hook_method(index, destination)) { + return std::move(*hook); + } else { + return {}; + } +} + } // namespace safetyhook \ No newline at end of file diff --git a/include/safetyhook/inline_hook.hpp b/include/safetyhook/inline_hook.hpp index f66b87f..b773d99 100644 --- a/include/safetyhook/inline_hook.hpp +++ b/include/safetyhook/inline_hook.hpp @@ -11,6 +11,7 @@ #include <vector> #include <safetyhook/allocator.hpp> +#include <safetyhook/utility.hpp> namespace safetyhook { /// @brief An inline hook. @@ -78,15 +79,12 @@ class InlineHook final { [[nodiscard]] static std::expected<InlineHook, Error> create(void* target, void* destination); /// @brief Create an inline hook. - /// @tparam T The type of the function to hook. /// @param target The address of the function to hook. /// @param destination The destination address. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). - template <typename T> - requires std::is_function_v<T> - [[nodiscard]] static std::expected<InlineHook, Error> create(T* target, T* destination) { + [[nodiscard]] static std::expected<InlineHook, Error> create(FnPtr auto target, FnPtr auto destination) { return create(reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination)); } @@ -100,16 +98,13 @@ class InlineHook final { const std::shared_ptr<Allocator>& allocator, void* target, void* destination); /// @brief Create an inline hook with a given Allocator. - /// @tparam T The type of the function to hook. /// @param allocator The allocator to use. /// @param target The address of the function to hook. /// @param destination The destination address. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). - template <typename T> - requires std::is_function_v<T> [[nodiscard]] static std::expected<InlineHook, Error> create( - const std::shared_ptr<Allocator>& allocator, T* target, T* destination) { + const std::shared_ptr<Allocator>& allocator, FnPtr auto target, FnPtr auto destination) { return create(allocator, reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination)); } @@ -216,7 +211,7 @@ class InlineHook final { /// @return The result of calling the original function. /// @note This function will use the default calling convention set by your compiler. /// @note This function is unsafe because it doesn't lock the mutex. Only use this if you don't care about unhook - // safety or are worried about the performance cost of locking the mutex. + /// safety or are worried about the performance cost of locking the mutex. template <typename RetT = void, typename... Args> RetT unsafe_call(Args... args) { return original<RetT (*)(Args...)>()(args...); } @@ -228,7 +223,7 @@ class InlineHook final { /// @return The result of calling the original function. /// @note This function will use the __cdecl calling convention. /// @note This function is unsafe because it doesn't lock the mutex. Only use this if you don't care about unhook - // safety or are worried about the performance cost of locking the mutex. + /// safety or are worried about the performance cost of locking the mutex. template <typename RetT = void, typename... Args> RetT unsafe_ccall(Args... args) { return original<RetT(__cdecl*)(Args...)>()(args...); } @@ -240,7 +235,7 @@ class InlineHook final { /// @return The result of calling the original function. /// @note This function will use the __thiscall calling convention. /// @note This function is unsafe because it doesn't lock the mutex. Only use this if you don't care about unhook - // safety or are worried about the performance cost of locking the mutex. + /// safety or are worried about the performance cost of locking the mutex. template <typename RetT = void, typename... Args> RetT unsafe_thiscall(Args... args) { return original<RetT(__thiscall*)(Args...)>()(args...); } @@ -252,7 +247,7 @@ class InlineHook final { /// @return The result of calling the original function. /// @note This function will use the __stdcall calling convention. /// @note This function is unsafe because it doesn't lock the mutex. Only use this if you don't care about unhook - // safety or are worried about the performance cost of locking the mutex. + /// safety or are worried about the performance cost of locking the mutex. template <typename RetT = void, typename... Args> RetT unsafe_stdcall(Args... args) { return original<RetT(__stdcall*)(Args...)>()(args...); } @@ -264,7 +259,7 @@ class InlineHook final { /// @return The result of calling the original function. /// @note This function will use the __fastcall calling convention. /// @note This function is unsafe because it doesn't lock the mutex. Only use this if you don't care about unhook - // safety or are worried about the performance cost of locking the mutex. + /// safety or are worried about the performance cost of locking the mutex. template <typename RetT = void, typename... Args> RetT unsafe_fastcall(Args... args) { return original<RetT(__fastcall*)(Args...)>()(args...); } @@ -287,4 +282,4 @@ class InlineHook final { void destroy(); }; -} // namespace safetyhook \ No newline at end of file +} // namespace safetyhook diff --git a/include/safetyhook/mid_hook.hpp b/include/safetyhook/mid_hook.hpp index c5574b4..4aa927c 100644 --- a/include/safetyhook/mid_hook.hpp +++ b/include/safetyhook/mid_hook.hpp @@ -9,6 +9,7 @@ #include <safetyhook/allocator.hpp> #include <safetyhook/context.hpp> #include <safetyhook/inline_hook.hpp> +#include <safetyhook/utility.hpp> namespace safetyhook { @@ -56,15 +57,12 @@ class MidHook final { [[nodiscard]] static std::expected<MidHook, Error> create(void* target, MidHookFn destination); /// @brief Creates a new MidHook object. - /// @tparam T The type of the function to hook. /// @param target The address of the function to hook. /// @param destination The destination function. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - template <typename T> - requires std::is_function_v<T> - [[nodiscard]] static std::expected<MidHook, Error> create(T* target, MidHookFn destination) { + [[nodiscard]] static std::expected<MidHook, Error> create(FnPtr auto target, MidHookFn destination) { return create(reinterpret_cast<void*>(target), destination); } @@ -84,10 +82,8 @@ class MidHook final { /// @param destination The destination function. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - template <typename T> - requires std::is_function_v<T> [[nodiscard]] static std::expected<MidHook, Error> create( - const std::shared_ptr<Allocator>& allocator, T* target, MidHookFn destination) { + const std::shared_ptr<Allocator>& allocator, FnPtr auto target, MidHookFn destination) { return create(allocator, reinterpret_cast<void*>(target), destination); } @@ -128,4 +124,4 @@ class MidHook final { std::expected<void, Error> setup( const std::shared_ptr<Allocator>& allocator, uint8_t* target, MidHookFn destination); }; -} // namespace safetyhook \ No newline at end of file +} // namespace safetyhook diff --git a/include/safetyhook/thread_freezer.hpp b/include/safetyhook/thread_freezer.hpp index 4cd5127..9f5d45f 100644 --- a/include/safetyhook/thread_freezer.hpp +++ b/include/safetyhook/thread_freezer.hpp @@ -5,7 +5,6 @@ #include <cstdint> #include <functional> -#include <vector> #include <Windows.h> diff --git a/include/safetyhook/utility.hpp b/include/safetyhook/utility.hpp index 277266e..25b3707 100644 --- a/include/safetyhook/utility.hpp +++ b/include/safetyhook/utility.hpp @@ -2,9 +2,15 @@ #include <algorithm> #include <cstdint> +#include <type_traits> namespace safetyhook { template <typename T> constexpr void store(uint8_t* address, const T& value) { std::copy_n(reinterpret_cast<const uint8_t*>(&value), sizeof(T), address); } -} // namespace safetyhook \ No newline at end of file + +template <typename T> +concept FnPtr = requires(T f) { std::is_pointer_v<T>&& std::is_function_v<std::remove_pointer_t<T>>; }; + +bool is_executable(uint8_t* address); +} // namespace safetyhook diff --git a/include/safetyhook/vmt_hook.hpp b/include/safetyhook/vmt_hook.hpp new file mode 100644 index 0000000..f35d560 --- /dev/null +++ b/include/safetyhook/vmt_hook.hpp @@ -0,0 +1,162 @@ +/// @file safetyhook/vmt_hook.hpp +/// @brief VMT hooking classes + +#pragma once + +#include <cstdint> +#include <expected> +#include <unordered_map> + +#include <safetyhook/allocator.hpp> +#include <safetyhook/utility.hpp> + +namespace safetyhook { +/// @brief A hook class that allows for hooking a single method in a VMT. +class VmHook final { +public: + VmHook() = default; + VmHook(const VmHook&) = delete; + VmHook(VmHook&& other) noexcept; + VmHook& operator=(const VmHook&) = delete; + VmHook& operator=(VmHook&& other) noexcept; + ~VmHook(); + + /// @brief Removes the hook. + void reset(); + + /// @brief Gets the original method pointer. + template <typename T> [[nodiscard]] T original() const { return reinterpret_cast<T>(m_original_vm); } + + /// @brief Calls the original method. + /// @tparam RetT The return type of the method. + /// @tparam Args The argument types of the method. + /// @param args The arguments to pass to the method. + /// @return The return value of the method. + /// @note This will call the original method with the default calling convention. + template <typename RetT = void, typename... Args> RetT call(Args... args) { + return original<RetT (*)(Args...)>()(args...); + } + + /// @brief Calls the original method with the __cdecl calling convention. + /// @tparam RetT The return type of the method. + /// @tparam Args The argument types of the method. + /// @param args The arguments to pass to the method. + /// @return The return value of the method. + template <typename RetT = void, typename... Args> RetT ccall(Args... args) { + return original<RetT(__cdecl*)(Args...)>()(args...); + } + + /// @brief Calls the original method with the __thiscall calling convention. + /// @tparam RetT The return type of the method. + /// @tparam Args The argument types of the method. + /// @param args The arguments to pass to the method. + /// @return The return value of the method. + template <typename RetT = void, typename... Args> RetT thiscall(Args... args) { + return original<RetT(__thiscall*)(Args...)>()(args...); + } + + /// @brief Calls the original method with the __stdcall calling convention. + /// @tparam RetT The return type of the method. + /// @tparam Args The argument types of the method. + /// @param args The arguments to pass to the method. + /// @return The return value of the method. + template <typename RetT = void, typename... Args> RetT stdcall(Args... args) { + return original<RetT(__stdcall*)(Args...)>()(args...); + } + + /// @brief Calls the original method with the __fastcall calling convention. + /// @tparam RetT The return type of the method. + /// @tparam Args The argument types of the method. + /// @param args The arguments to pass to the method. + /// @return The return value of the method. + template <typename RetT = void, typename... Args> RetT fastcall(Args... args) { + return original<RetT(__fastcall*)(Args...)>()(args...); + } + +private: + friend class VmtHook; + + uint8_t* m_original_vm{}; + uint8_t* m_new_vm{}; + uint8_t** m_vmt_entry{}; + + // This keeps the allocation alive until the hook is destroyed. + std::shared_ptr<Allocation> m_new_vmt_allocation{}; + + void destroy(); +}; + +/// @brief A hook class that copies an entire VMT for a given object and replaces it. +class VmtHook final { +public: + /// @brief Error type for VmtHook. + struct Error { + /// @brief The type of error. + enum : uint8_t { + BAD_ALLOCATION, ///< An error occurred while allocating memory. + } type; + + /// @brief Extra error information. + union { + Allocator::Error allocator_error; ///< Allocator error information. + }; + + /// @brief Create a BAD_ALLOCATION error. + /// @param err The Allocator::Error that failed. + /// @return The new BAD_ALLOCATION error. + [[nodiscard]] static Error bad_allocation(Allocator::Error err) { + return {.type = BAD_ALLOCATION, .allocator_error = err}; + } + }; + + /// @brief Creates a new VmtHook object. Will clone the VMT of the given object and replace it. + /// @param object The object to hook. + /// @return The VmtHook object or a VmtHook::Error if an error occurred. + [[nodiscard]] static std::expected<VmtHook, Error> create(void* object); + + VmtHook() = default; + VmtHook(const VmtHook&) = delete; + VmtHook(VmtHook&& other) noexcept; + VmtHook& operator=(const VmtHook&) = delete; + VmtHook& operator=(VmtHook&& other) noexcept; + ~VmtHook(); + + /// @brief Applies the hook. + /// @param object The object to apply the hook to. + /// @note This will replace the VMT of the object with the new VMT. You can apply the hook to multiple objects. + void apply(void* object); + + /// @brief Removes the hook. + /// @param object The object to remove the hook from. + void remove(void* object); + + /// @brief Removes the hook from all objects. + void reset(); + + /// @brief Hooks a method in the VMT. + /// @param index The index of the method to hook. + /// @param new_function The new function to use. + [[nodiscard]] std::expected<VmHook, Error> hook_method(size_t index, FnPtr auto new_function) { + VmHook hook{}; + + ++index; // Skip RTTI pointer. + hook.m_original_vm = m_new_vmt[index]; + store(reinterpret_cast<uint8_t*>(&hook.m_new_vm), new_function); + hook.m_vmt_entry = &m_new_vmt[index]; + hook.m_new_vmt_allocation = m_new_vmt_allocation; + m_new_vmt[index] = hook.m_new_vm; + + return hook; + } + +private: + // Map of object instance to their original VMT. + std::unordered_map<void*, uint8_t**> m_objects{}; + + // The allocation is a shared_ptr, so it can be shared with VmHooks to ensure the memory is kept alive. + std::shared_ptr<Allocation> m_new_vmt_allocation{}; + uint8_t** m_new_vmt{}; + + void destroy(); +}; +} // namespace safetyhook diff --git a/src/easy.cpp b/src/easy.cpp index 2508007..f79fbbb 100644 --- a/src/easy.cpp +++ b/src/easy.cpp @@ -17,4 +17,11 @@ MidHook create_mid(void* target, MidHookFn destination) { } } +VmtHook create_vmt(void* object) { + if (auto hook = VmtHook::create(object)) { + return std::move(*hook); + } else { + return {}; + } +} } // namespace safetyhook \ No newline at end of file diff --git a/src/utility.cpp b/src/utility.cpp new file mode 100644 index 0000000..616204a --- /dev/null +++ b/src/utility.cpp @@ -0,0 +1,50 @@ +#include <Windows.h> + +#include <safetyhook/utility.hpp> + +namespace safetyhook { +bool is_page_executable(uint8_t* address) { + MEMORY_BASIC_INFORMATION mbi; + + if (VirtualQuery(address, &mbi, sizeof(mbi)) == 0) { + return false; + } + + const auto executable_protect = PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE | PAGE_EXECUTE_WRITECOPY; + + return (mbi.Protect & executable_protect) != 0; +} + +bool is_executable(uint8_t* address) { + LPVOID image_base_ptr; + + if (RtlPcToFileHeader(address, &image_base_ptr) == nullptr) { + return is_page_executable(address); + } + + // Just check if the section is executable. + const auto* image_base = reinterpret_cast<uint8_t*>(image_base_ptr); + const auto* dos_hdr = reinterpret_cast<const IMAGE_DOS_HEADER*>(image_base); + + if (dos_hdr->e_magic != IMAGE_DOS_SIGNATURE) { + return is_page_executable(address); + } + + const auto* nt_hdr = reinterpret_cast<const IMAGE_NT_HEADERS*>(image_base + dos_hdr->e_lfanew); + + if (nt_hdr->Signature != IMAGE_NT_SIGNATURE) { + return is_page_executable(address); + } + + const auto* section = IMAGE_FIRST_SECTION(nt_hdr); + + for (auto i = 0; i < nt_hdr->FileHeader.NumberOfSections; ++i, ++section) { + if (address >= image_base + section->VirtualAddress && + address < image_base + section->VirtualAddress + section->Misc.VirtualSize) { + return (section->Characteristics & IMAGE_SCN_MEM_EXECUTE) != 0; + } + } + + return is_page_executable(address); +} +} // namespace safetyhook \ No newline at end of file diff --git a/src/vmt_hook.cpp b/src/vmt_hook.cpp new file mode 100644 index 0000000..cec6512 --- /dev/null +++ b/src/vmt_hook.cpp @@ -0,0 +1,147 @@ +#include <Windows.h> + +#include <safetyhook/thread_freezer.hpp> + +#include <safetyhook/vmt_hook.hpp> + +namespace safetyhook { +VmHook::VmHook(VmHook&& other) noexcept { + *this = std::move(other); +} + +VmHook& VmHook::operator=(VmHook&& other) noexcept { + destroy(); + m_original_vm = other.m_original_vm; + m_new_vm = other.m_new_vm; + m_vmt_entry = other.m_vmt_entry; + m_new_vmt_allocation = std::move(other.m_new_vmt_allocation); + other.m_original_vm = nullptr; + other.m_new_vm = nullptr; + other.m_vmt_entry = nullptr; + return *this; +} + +VmHook::~VmHook() { + destroy(); +} + +void VmHook::reset() { + *this = {}; +} + +void VmHook::destroy() { + if (m_original_vm != nullptr) { + *m_vmt_entry = m_original_vm; + m_original_vm = nullptr; + m_new_vm = nullptr; + m_vmt_entry = nullptr; + m_new_vmt_allocation.reset(); + } +} + +std::expected<VmtHook, VmtHook::Error> VmtHook::create(void* object) { + VmtHook hook{}; + + const auto original_vmt = *reinterpret_cast<uint8_t***>(object); + hook.m_objects.emplace(object, original_vmt); + + // Count the number of virtual method pointers. We start at one to account for the RTTI pointer. + auto num_vmt_entries = 1; + + for (auto vm = original_vmt; is_executable(*vm); ++vm) { + ++num_vmt_entries; + } + + // Allocate memory for the new VMT. + auto allocation = Allocator::global()->allocate(num_vmt_entries * sizeof(uint8_t*)); + + if (!allocation) { + return std::unexpected{Error::bad_allocation(allocation.error())}; + } + + hook.m_new_vmt_allocation = std::make_shared<Allocation>(std::move(*allocation)); + hook.m_new_vmt = reinterpret_cast<uint8_t**>(hook.m_new_vmt_allocation->data()); + + // Copy pointer to RTTI. + hook.m_new_vmt[0] = original_vmt[-1]; + + // Copy virtual method pointers. + for (auto i = 0; i < num_vmt_entries - 1; ++i) { + hook.m_new_vmt[i + 1] = original_vmt[i]; + } + + *reinterpret_cast<uint8_t***>(object) = &hook.m_new_vmt[1]; + + return hook; +} + +VmtHook::VmtHook(VmtHook&& other) noexcept { + *this = std::move(other); +} + +VmtHook& VmtHook::operator=(VmtHook&& other) noexcept { + destroy(); + m_objects = std::move(other.m_objects); + m_new_vmt_allocation = std::move(other.m_new_vmt_allocation); + m_new_vmt = other.m_new_vmt; + other.m_new_vmt = nullptr; + return *this; +} + +VmtHook::~VmtHook() { + destroy(); +} + +void VmtHook::apply(void* object) { + m_objects.emplace(object, *reinterpret_cast<uint8_t***>(object)); + *reinterpret_cast<uint8_t***>(object) = &m_new_vmt[1]; +} + +void VmtHook::remove(void* object) { + const auto search = m_objects.find(object); + + if (search == m_objects.end()) { + return; + } + + const auto original_vmt = search->second; + + execute_while_frozen([&] { + if (IsBadWritePtr(object, sizeof(void*))) { + return; + } + + if (*reinterpret_cast<uint8_t***>(object) != &m_new_vmt[1]) { + return; + } + + *reinterpret_cast<uint8_t***>(object) = original_vmt; + }); + + m_objects.erase(search); +} + +void VmtHook::reset() { + *this = {}; +} + +void VmtHook::destroy() { + execute_while_frozen([this] { + for (const auto [object, original_vmt] : m_objects) { + if (IsBadWritePtr(object, sizeof(void*))) { + return; + } + + if (*reinterpret_cast<uint8_t***>(object) != &m_new_vmt[1]) { + return; + } + + *reinterpret_cast<uint8_t***>(object) = original_vmt; + } + }); + + m_objects.clear(); + m_new_vmt_allocation.reset(); + m_new_vmt = nullptr; +} +} // namespace safetyhook \ No newline at end of file diff --git a/unittest/vmt_hook.cpp b/unittest/vmt_hook.cpp new file mode 100644 index 0000000..3365ebe --- /dev/null +++ b/unittest/vmt_hook.cpp @@ -0,0 +1,331 @@ +#include <catch2/catch_test_macros.hpp> +#include <safetyhook.hpp> + +TEST_CASE("VMT hook an object instance", "[vmt_hook]") { + struct Interface { + virtual ~Interface() = default; + virtual int add_42(int a) = 0; + }; + + struct Target : Interface { + __declspec(noinline) int add_42(int a) override { return a + 42; } + }; + + std::unique_ptr<Interface> target = std::make_unique<Target>(); + + REQUIRE(target->add_42(0) == 42); + + static SafetyHookVmt target_hook{}; + static SafetyHookVm add_42_hook{}; + + struct Hook : Target { + int hooked_add_42(int a) { return add_42_hook.thiscall<int>(this, a) + 1337; } + }; + + auto vmt_result = SafetyHookVmt::create(target.get()); + + REQUIRE(vmt_result); + + target_hook = std::move(*vmt_result); + + auto vm_result = target_hook.hook_method(1, &Hook::hooked_add_42); + + REQUIRE(vm_result); + + add_42_hook = std::move(*vm_result); + + REQUIRE(target->add_42(1) == 1380); + + add_42_hook.reset(); + + REQUIRE(target->add_42(2) == 44); +} + +TEST_CASE("Resetting the VMT hook removes all VM hooks for that object", "[vmt_hook]") { + struct Interface { + virtual ~Interface() = default; + virtual int add_42(int a) = 0; + virtual int add_43(int a) = 0; + }; + + struct Target : Interface { + __declspec(noinline) int add_42(int a) override { return a + 42; } + __declspec(noinline) int add_43(int a) override { return a + 43; } + }; + + std::unique_ptr<Interface> target = std::make_unique<Target>(); + + REQUIRE(target->add_42(0) == 42); + REQUIRE(target->add_43(0) == 43); + + static SafetyHookVmt target_hook{}; + static SafetyHookVm add_42_hook{}; + static SafetyHookVm add_43_hook{}; + + struct Hook : Target { + int hooked_add_42(int a) { return add_42_hook.thiscall<int>(this, a) + 1337; } + int hooked_add_43(int a) { return add_43_hook.thiscall<int>(this, a) + 1337; } + }; + + auto vmt_result = SafetyHookVmt::create(target.get()); + + REQUIRE(vmt_result); + + target_hook = std::move(*vmt_result); + + auto vm_result = target_hook.hook_method(1, &Hook::hooked_add_42); + + REQUIRE(vm_result); + + add_42_hook = std::move(*vm_result); + + REQUIRE(target->add_42(1) == 1380); + + vm_result = target_hook.hook_method(2, &Hook::hooked_add_43); + + REQUIRE(vm_result); + + add_43_hook = std::move(*vm_result); + + REQUIRE(target->add_43(1) == 1381); + + target_hook.reset(); + + REQUIRE(target->add_42(2) == 44); + REQUIRE(target->add_43(2) == 45); +} + +TEST_CASE("VMT hooking an object maintains correct RTTI", "[vmt_hook]") { + struct Interface { + virtual ~Interface() = default; + virtual int add_42(int a) = 0; + }; + + struct Target : Interface { + __declspec(noinline) int add_42(int a) override { return a + 42; } + }; + + std::unique_ptr<Interface> target = std::make_unique<Target>(); + + REQUIRE(target->add_42(0) == 42); + REQUIRE(dynamic_cast<Target*>(target.get()) != nullptr); + + static SafetyHookVmt target_hook{}; + static SafetyHookVm add_42_hook{}; + + struct Hook : Target { + int hooked_add_42(int a) { return add_42_hook.thiscall<int>(this, a) + 1337; } + }; + + auto vmt_result = SafetyHookVmt::create(target.get()); + + REQUIRE(vmt_result); + + target_hook = std::move(*vmt_result); + + auto vm_result = target_hook.hook_method(1, &Hook::hooked_add_42); + + REQUIRE(vm_result); + + add_42_hook = std::move(*vm_result); + + REQUIRE(target->add_42(1) == 1380); + REQUIRE(dynamic_cast<Target*>(target.get()) != nullptr); +} + +TEST_CASE("Can safely destroy VmtHook after object is deleted", "[vmt_hook]") { + struct Interface { + virtual ~Interface() = default; + virtual int add_42(int a) = 0; + }; + + struct Target : Interface { + __declspec(noinline) int add_42(int a) override { return a + 42; } + }; + + std::unique_ptr<Interface> target = std::make_unique<Target>(); + + REQUIRE(target->add_42(0) == 42); + + static SafetyHookVmt target_hook{}; + static SafetyHookVm add_42_hook{}; + + struct Hook : Target { + int hooked_add_42(int a) { return add_42_hook.thiscall<int>(this, a) + 1337; } + }; + + auto vmt_result = SafetyHookVmt::create(target.get()); + + REQUIRE(vmt_result); + + target_hook = std::move(*vmt_result); + + auto vm_result = target_hook.hook_method(1, &Hook::hooked_add_42); + + REQUIRE(vm_result); + + add_42_hook = std::move(*vm_result); + + REQUIRE(target->add_42(1) == 1380); + + target.reset(); + target_hook.reset(); +} + +TEST_CASE("Can apply an existing VMT hook to more than one object", "[vmt_hook]") { + struct Interface { + virtual ~Interface() = default; + virtual int add_42(int a) = 0; + }; + + struct Target : Interface { + __declspec(noinline) int add_42(int a) override { return a + 42; } + }; + + std::unique_ptr<Interface> target = std::make_unique<Target>(); + std::unique_ptr<Interface> target0 = std::make_unique<Target>(); + std::unique_ptr<Interface> target1 = std::make_unique<Target>(); + std::unique_ptr<Interface> target2 = std::make_unique<Target>(); + + REQUIRE(target->add_42(0) == 42); + + static SafetyHookVmt target_hook{}; + static SafetyHookVm add_42_hook{}; + + struct Hook : Target { + int hooked_add_42(int a) { return add_42_hook.thiscall<int>(this, a) + 1337; } + }; + + auto vmt_result = SafetyHookVmt::create(target.get()); + + REQUIRE(vmt_result); + + target_hook = std::move(*vmt_result); + + auto vm_result = target_hook.hook_method(1, &Hook::hooked_add_42); + + REQUIRE(vm_result); + + add_42_hook = std::move(*vm_result); + + target_hook.apply(target0.get()); + target_hook.apply(target1.get()); + target_hook.apply(target2.get()); + + REQUIRE(target->add_42(1) == 1380); + REQUIRE(target0->add_42(1) == 1380); + REQUIRE(target1->add_42(1) == 1380); + REQUIRE(target2->add_42(1) == 1380); + + add_42_hook.reset(); + + REQUIRE(target->add_42(2) == 44); + REQUIRE(target0->add_42(2) == 44); + REQUIRE(target1->add_42(2) == 44); + REQUIRE(target2->add_42(2) == 44); +} + +TEST_CASE("Can remove an object that was previously VMT hooked", "[vmt_hook]") { + struct Interface { + virtual ~Interface() = default; + virtual int add_42(int a) = 0; + }; + + struct Target : Interface { + __declspec(noinline) int add_42(int a) override { return a + 42; } + }; + + std::unique_ptr<Interface> target = std::make_unique<Target>(); + std::unique_ptr<Interface> target0 = std::make_unique<Target>(); + std::unique_ptr<Interface> target1 = std::make_unique<Target>(); + std::unique_ptr<Interface> target2 = std::make_unique<Target>(); + + REQUIRE(target->add_42(0) == 42); + + static SafetyHookVmt target_hook{}; + static SafetyHookVm add_42_hook{}; + + struct Hook : Target { + int hooked_add_42(int a) { return add_42_hook.thiscall<int>(this, a) + 1337; } + }; + + auto vmt_result = SafetyHookVmt::create(target.get()); + + REQUIRE(vmt_result); + + target_hook = std::move(*vmt_result); + + auto vm_result = target_hook.hook_method(1, &Hook::hooked_add_42); + + REQUIRE(vm_result); + + add_42_hook = std::move(*vm_result); + + target_hook.apply(target0.get()); + target_hook.apply(target1.get()); + target_hook.apply(target2.get()); + + REQUIRE(target->add_42(1) == 1380); + REQUIRE(target0->add_42(1) == 1380); + REQUIRE(target1->add_42(1) == 1380); + REQUIRE(target2->add_42(1) == 1380); + + target_hook.remove(target0.get()); + + REQUIRE(target->add_42(2) == 1381); + REQUIRE(target0->add_42(2) == 44); + REQUIRE(target1->add_42(2) == 1381); + REQUIRE(target2->add_42(2) == 1381); + + target_hook.remove(target2.get()); + + REQUIRE(target->add_42(2) == 1381); + REQUIRE(target0->add_42(2) == 44); + REQUIRE(target1->add_42(2) == 1381); + REQUIRE(target2->add_42(2) == 44); + + target_hook.remove(target.get()); + + REQUIRE(target->add_42(2) == 44); + REQUIRE(target0->add_42(2) == 44); + REQUIRE(target1->add_42(2) == 1381); + REQUIRE(target2->add_42(2) == 44); + + target_hook.remove(target1.get()); + + REQUIRE(target->add_42(2) == 44); + REQUIRE(target0->add_42(2) == 44); + REQUIRE(target1->add_42(2) == 44); + REQUIRE(target2->add_42(2) == 44); +} + +TEST_CASE("VMT hook an object instance with easy API", "[vmt_hook]") { + struct Interface { + virtual ~Interface() = default; + virtual int add_42(int a) = 0; + }; + + struct Target : Interface { + __declspec(noinline) int add_42(int a) override { return a + 42; } + }; + + std::unique_ptr<Interface> target = std::make_unique<Target>(); + + REQUIRE(target->add_42(0) == 42); + + static SafetyHookVmt target_hook{}; + static SafetyHookVm add_42_hook{}; + + struct Hook : Target { + int hooked_add_42(int a) { return add_42_hook.thiscall<int>(this, a) + 1337; } + }; + + target_hook = safetyhook::create_vmt(target.get()); + add_42_hook = safetyhook::create_vm(target_hook, 1, &Hook::hooked_add_42); + + REQUIRE(target->add_42(1) == 1380); + + add_42_hook.reset(); + + REQUIRE(target->add_42(2) == 44); +}