From 79ba8c6cf4e22ec7efa4d016d7d6305501c8a8a8 Mon Sep 17 00:00:00 2001 From: Skull <86374920+skkuull@users.noreply.github.com> Date: Mon, 17 Mar 2025 06:08:13 +0300 Subject: [PATCH] tweak(loader): refactor and update --- src/client/loader/component_interface.hpp | 2 +- src/client/loader/loader.cpp | 354 +++++++++++----------- src/client/loader/loader.hpp | 20 +- src/client/loader/seh.cpp | 151 --------- src/client/loader/seh.hpp | 38 --- src/client/loader/tls.cpp | 10 +- src/client/main.cpp | 126 ++++++-- src/common/utils/io.cpp | 4 +- src/common/utils/nt.cpp | 87 +++--- src/common/utils/nt.hpp | 108 ++----- 10 files changed, 371 insertions(+), 529 deletions(-) delete mode 100644 src/client/loader/seh.cpp delete mode 100644 src/client/loader/seh.hpp diff --git a/src/client/loader/component_interface.hpp b/src/client/loader/component_interface.hpp index c0f91c66..9e6cef79 100644 --- a/src/client/loader/component_interface.hpp +++ b/src/client/loader/component_interface.hpp @@ -4,7 +4,7 @@ enum class component_priority { min = 0, dvars, - steam_proxy, + uwp, arxan, updater, }; diff --git a/src/client/loader/loader.cpp b/src/client/loader/loader.cpp index 4ec12c7c..4a2c1fae 100644 --- a/src/client/loader/loader.cpp +++ b/src/client/loader/loader.cpp @@ -1,208 +1,218 @@ #include #include "loader.hpp" -#include "seh.hpp" #include "tls.hpp" #include #include -FARPROC loader::load(const utils::nt::library& library, const std::string& buffer) const +namespace loader { - if (buffer.empty()) return nullptr; - - const utils::nt::library source(HMODULE(buffer.data())); - if (!source) return nullptr; - - this->load_sections(library, source); - this->load_imports(library, source); - this->load_exception_table(library, source); - this->load_tls(library, source); - - DWORD old_protect; - VirtualProtect(library.get_nt_headers(), 0x1000, PAGE_EXECUTE_READWRITE, &old_protect); - - library.get_optional_header()->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT] = source - .get_optional_header()->DataDirectory[ - IMAGE_DIRECTORY_ENTRY_IMPORT]; - std::memmove(library.get_nt_headers(), source.get_nt_headers(), - sizeof(IMAGE_NT_HEADERS) + source.get_nt_headers()->FileHeader.NumberOfSections * sizeof( - IMAGE_SECTION_HEADER)); - - return FARPROC(library.get_ptr() + source.get_relative_entry_point()); -} - -FARPROC loader::load_library(const std::string& filename, uint64_t* base_address) const -{ - const auto target = utils::nt::library::load(filename); - if (!target) + namespace { - throw std::runtime_error{"Failed to map binary!"}; - } - - const auto base = size_t(target.get_ptr()); - *base_address = base; - - this->load_imports(target, target); - this->load_tls(target, target); - - return FARPROC(target.get_ptr() + target.get_relative_entry_point()); -} - -void loader::set_import_resolver(const std::function& resolver) -{ - this->import_resolver_ = resolver; -} - -void loader::load_section(const utils::nt::library& target, const utils::nt::library& source, - IMAGE_SECTION_HEADER* section) -{ - void* target_ptr = target.get_ptr() + section->VirtualAddress; - const void* source_ptr = source.get_ptr() + section->PointerToRawData; - - if (PBYTE(target_ptr) >= (target.get_ptr() + BINARY_PAYLOAD_SIZE)) - { - throw std::runtime_error("Section exceeds the binary payload size, please increase it!"); - } - - if (section->SizeOfRawData > 0) - { - std::memmove(target_ptr, source_ptr, section->SizeOfRawData); - - DWORD old_protect; - VirtualProtect(target_ptr, section->Misc.VirtualSize, PAGE_EXECUTE_READWRITE, &old_protect); - } -} - -void loader::load_sections(const utils::nt::library& target, const utils::nt::library& source) const -{ - for (auto& section : source.get_section_headers()) - { - this->load_section(target, source, section); - } -} - -void loader::load_imports(const utils::nt::library& target, const utils::nt::library& source) const -{ - auto* const import_directory = &source.get_optional_header()->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]; - - auto* descriptor = PIMAGE_IMPORT_DESCRIPTOR(target.get_ptr() + import_directory->VirtualAddress); - - while (descriptor->Name) - { - std::string name = LPSTR(target.get_ptr() + descriptor->Name); - - auto* name_table_entry = reinterpret_cast(target.get_ptr() + descriptor->OriginalFirstThunk); - auto* address_table_entry = reinterpret_cast(target.get_ptr() + descriptor->FirstThunk); - - if (!descriptor->OriginalFirstThunk) + template + T offset_pointer(void* data, const ptrdiff_t offset) { - name_table_entry = reinterpret_cast(target.get_ptr() + descriptor->FirstThunk); + return reinterpret_cast(reinterpret_cast(data) + offset); } - while (*name_table_entry) + std::function import_resolver_; + + void load_section(const utils::nt::library& target, const utils::nt::library& source, + IMAGE_SECTION_HEADER* section) { - FARPROC function = nullptr; - std::string function_name; - const char* function_procname; + void* target_ptr = target.get_ptr() + section->VirtualAddress; + const void* source_ptr = source.get_ptr() + section->PointerToRawData; - if (IMAGE_SNAP_BY_ORDINAL(*name_table_entry)) + if (PBYTE(target_ptr) >= (target.get_ptr() + BINARY_PAYLOAD_SIZE)) { - function_name = "#" + std::to_string(IMAGE_ORDINAL(*name_table_entry)); - function_procname = MAKEINTRESOURCEA(IMAGE_ORDINAL(*name_table_entry)); - } - else - { - auto* import = PIMAGE_IMPORT_BY_NAME(target.get_ptr() + *name_table_entry); - function_name = import->Name; - function_procname = function_name.data(); + throw std::runtime_error("Section exceeds the binary payload size, please increase it!"); } - if (this->import_resolver_) function = FARPROC(this->import_resolver_(name, function_name)); - if (!function) + if (section->SizeOfRawData > 0) { - auto library = utils::nt::library::load(name); - if (library) + std::memmove(target_ptr, source_ptr, section->SizeOfRawData); + + DWORD old_protect; + VirtualProtect(target_ptr, section->Misc.VirtualSize, PAGE_EXECUTE_READWRITE, &old_protect); + } + } + + void load_sections(const utils::nt::library& target, const utils::nt::library& source) + { + for (auto& section : source.get_section_headers()) + { + load_section(target, source, section); + } + } + + void load_imports(const utils::nt::library& target) + { + const auto* const import_directory = &target.get_optional_header()->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]; + + auto* descriptor = PIMAGE_IMPORT_DESCRIPTOR(target.get_ptr() + import_directory->VirtualAddress); + + while (descriptor->Name) + { + std::string name = LPSTR(target.get_ptr() + descriptor->Name); + + auto* name_table_entry = reinterpret_cast(target.get_ptr() + descriptor->OriginalFirstThunk); + auto* address_table_entry = reinterpret_cast(target.get_ptr() + descriptor->FirstThunk); + + if (!descriptor->OriginalFirstThunk) { - function = GetProcAddress(library, function_procname); + name_table_entry = reinterpret_cast(target.get_ptr() + descriptor->FirstThunk); } + + while (*name_table_entry) + { + FARPROC function = nullptr; + std::string function_name; + const char* function_procname; + + if (IMAGE_SNAP_BY_ORDINAL(*name_table_entry)) + { + function_name = "#" + std::to_string(IMAGE_ORDINAL(*name_table_entry)); + function_procname = MAKEINTRESOURCEA(IMAGE_ORDINAL(*name_table_entry)); + } + else + { + auto* import = PIMAGE_IMPORT_BY_NAME(target.get_ptr() + *name_table_entry); + function_name = import->Name; + function_procname = function_name.data(); + } + + auto library = utils::nt::library::load(name); + if (library) + { + function = GetProcAddress(library, function_procname); + } + + if (!function) + { + throw std::runtime_error(utils::string::va("Unable to load import '%s' from library '%s'", + function_name.data(), name.data())); + } + + utils::hook::set(address_table_entry, reinterpret_cast(function)); + + name_table_entry++; + address_table_entry++; + } + + descriptor++; } - - if (!function) - { - throw std::runtime_error(utils::string::va("Unable to load import '%s' from library '%s'", - function_name.data(), name.data())); - } - - utils::hook::set(address_table_entry, reinterpret_cast(function)); - - name_table_entry++; - address_table_entry++; } - descriptor++; - } -} - -void loader::load_exception_table(const utils::nt::library& target, const utils::nt::library& source) const -{ - auto* exception_directory = &source.get_optional_header()->DataDirectory[IMAGE_DIRECTORY_ENTRY_EXCEPTION]; - - auto* function_list = PRUNTIME_FUNCTION(target.get_ptr() + exception_directory->VirtualAddress); - const auto entry_count = ULONG(exception_directory->Size / sizeof(RUNTIME_FUNCTION)); - - if (!RtlAddFunctionTable(function_list, entry_count, DWORD64(target.get_ptr()))) - { - MessageBoxA(nullptr, "Setting exception handlers failed.", "Error", MB_OK | MB_ICONERROR); - } - - { - const utils::nt::library ntdll("ntdll.dll"); - - auto* const table_list_head = ntdll.invoke_pascal("RtlGetFunctionTableListHead"); - auto* table_list_entry = table_list_head->Flink; - - while (table_list_entry != table_list_head) + void load_relocations(const utils::nt::library& target) { - auto* const function_table = CONTAINING_RECORD(table_list_entry, DYNAMIC_FUNCTION_TABLE, Links); - - if (function_table->BaseAddress == ULONG_PTR(target.get_handle())) + if (!utils::nt::is_wine()) { - function_table->EntryCount = entry_count; - function_table->FunctionTable = function_list; + return; } - table_list_entry = function_table->Links.Flink; + auto* current_base = target.get_ptr(); + const auto initial_base = target.get_optional_header()->ImageBase; + const auto delta = reinterpret_cast(current_base) - initial_base; + + PIMAGE_DATA_DIRECTORY directory = &target.get_optional_header()->DataDirectory[ + IMAGE_DIRECTORY_ENTRY_BASERELOC]; + if (directory->Size == 0) + { + return; + } + + auto* relocation = reinterpret_cast(current_base + directory->VirtualAddress); + while (relocation->VirtualAddress > 0) + { + unsigned char* dest = current_base + relocation->VirtualAddress; + + auto* rel_info = offset_pointer(relocation, sizeof(IMAGE_BASE_RELOCATION)); + const auto* rel_info_end = offset_pointer( + rel_info, relocation->SizeOfBlock - sizeof(IMAGE_BASE_RELOCATION)); + + for (; rel_info < rel_info_end; ++rel_info) + { + const int type = *rel_info >> 12; + const int offset = *rel_info & 0xfff; + + switch (type) + { + case IMAGE_REL_BASED_ABSOLUTE: + break; + + case IMAGE_REL_BASED_HIGHLOW: + { + auto* patch_address = reinterpret_cast(dest + offset); + utils::hook::set(patch_address, *patch_address + static_cast(delta)); + break; + } + + case IMAGE_REL_BASED_DIR64: + { + auto* patch_address = reinterpret_cast(dest + offset); + utils::hook::set(patch_address, *patch_address + static_cast(delta)); + break; + } + + default: + throw std::runtime_error("Unknown relocation type: " + std::to_string(type)); + } + } + + relocation = offset_pointer(relocation, relocation->SizeOfBlock); + } + } + + void load_tls(const utils::nt::library& target) + { + if (target.get_optional_header()->DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS].Size) + { + auto* target_tls = tls::allocate_tls_index(); + auto* const source_tls = reinterpret_cast(target.get_ptr() + target.get_optional_header() + ->DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS].VirtualAddress); + + auto* target_tls_start = PVOID(target_tls->StartAddressOfRawData); + auto* tls_start = PVOID(source_tls->StartAddressOfRawData); + const auto tls_size = source_tls->EndAddressOfRawData - source_tls->StartAddressOfRawData; + const auto tls_index = *reinterpret_cast(target_tls->AddressOfIndex); + + utils::hook::set(source_tls->AddressOfIndex, tls_index); + + if (target_tls->AddressOfCallBacks) + { + utils::hook::set(target_tls->AddressOfCallBacks, nullptr); + } + + DWORD old_protect; + VirtualProtect(target_tls_start, tls_size, PAGE_READWRITE, &old_protect); + + auto* const tls_base = *reinterpret_cast(__readgsqword(0x58) + 8ull * tls_index); + std::memmove(tls_base, tls_start, tls_size); + std::memmove(target_tls_start, tls_start, tls_size); + + VirtualProtect(target_tls, sizeof(*target_tls), PAGE_READWRITE, &old_protect); + *target_tls = *source_tls; + } } } - seh::setup_handler(target.get_ptr(), target.get_ptr() + source.get_optional_header()->SizeOfImage, function_list, - entry_count); -} - -void loader::load_tls(const utils::nt::library& target, const utils::nt::library& source) const -{ - if (source.get_optional_header()->DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS].Size) + utils::nt::library load_binary(const std::string& filename) { - auto* target_tls = tls::allocate_tls_index(); - /* target_tls = reinterpret_cast(library.get_ptr() + library.get_optional_header() - ->DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS].VirtualAddress); */ - auto* const source_tls = reinterpret_cast(target.get_ptr() + source.get_optional_header() - ->DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS].VirtualAddress); + const auto target = utils::nt::library::load(filename); + if (!target) + { + throw std::runtime_error{ "Failed to map: " + filename }; + } - const auto tls_size = source_tls->EndAddressOfRawData - source_tls->StartAddressOfRawData; - const auto tls_index = *reinterpret_cast(target_tls->AddressOfIndex); - utils::hook::set(source_tls->AddressOfIndex, tls_index); + load_relocations(target); + load_imports(target); + load_tls(target); - DWORD old_protect; - VirtualProtect(PVOID(target_tls->StartAddressOfRawData), - source_tls->EndAddressOfRawData - source_tls->StartAddressOfRawData, PAGE_READWRITE, - &old_protect); + return target; + } - auto* const tls_base = *reinterpret_cast(__readgsqword(0x58) + 8ull * tls_index); - std::memmove(tls_base, PVOID(source_tls->StartAddressOfRawData), tls_size); - std::memmove(PVOID(target_tls->StartAddressOfRawData), PVOID(source_tls->StartAddressOfRawData), tls_size); - - VirtualProtect(target_tls, sizeof(*target_tls), PAGE_READWRITE, &old_protect); - *target_tls = *source_tls; + void set_import_resolver(const std::function& resolver) + { + import_resolver_ = resolver; } } diff --git a/src/client/loader/loader.hpp b/src/client/loader/loader.hpp index 0c0b5a12..ec135cf9 100644 --- a/src/client/loader/loader.hpp +++ b/src/client/loader/loader.hpp @@ -1,21 +1,9 @@ #pragma once #include -class loader final +namespace loader { -public: - FARPROC load(const utils::nt::library& library, const std::string& buffer) const; - FARPROC load_library(const std::string& filename, uint64_t* base_address) const; + utils::nt::library load_binary(const std::string& filename); - void set_import_resolver(const std::function& resolver); - -private: - std::function import_resolver_; - - static void load_section(const utils::nt::library& target, const utils::nt::library& source, - IMAGE_SECTION_HEADER* section); - void load_sections(const utils::nt::library& target, const utils::nt::library& source) const; - void load_imports(const utils::nt::library& target, const utils::nt::library& source) const; - void load_exception_table(const utils::nt::library& target, const utils::nt::library& source) const; - void load_tls(const utils::nt::library& target, const utils::nt::library& source) const; -}; + void set_import_resolver(const std::function& resolver); +} \ No newline at end of file diff --git a/src/client/loader/seh.cpp b/src/client/loader/seh.cpp deleted file mode 100644 index a3364a8f..00000000 --- a/src/client/loader/seh.cpp +++ /dev/null @@ -1,151 +0,0 @@ -#include - -#include -#include - -#include "seh.hpp" - -namespace seh -{ - namespace - { - void*(*rtlpx_lookup_function_table)(void*, FUNCTION_TABLE_DATA*); - void*(*rtlpx_lookup_function_table_down_level)(void*, PDWORD64, PULONG); - - FUNCTION_TABLE_DATA overridden_table; - - DWORD64 override_end; - DWORD64 override_start; - - void* find_call_from_address(void* method_ptr, ud_mnemonic_code mnemonic = UD_Icall) - { - ud_t ud; - ud_init(&ud); - ud_set_mode(&ud, 64); - ud_set_pc(&ud, reinterpret_cast(method_ptr)); - ud_set_input_buffer(&ud, static_cast(method_ptr), INT32_MAX); - - void* retval = nullptr; - while (true) - { - ud_disassemble(&ud); - - if (ud_insn_mnemonic(&ud) == UD_Iint3) break; - if (ud_insn_mnemonic(&ud) == mnemonic) - { - const auto* const operand = ud_insn_opr(&ud, 0); - if (operand->type == UD_OP_JIMM) - { - if (!retval) retval = reinterpret_cast(ud_insn_len(&ud) + ud_insn_off(&ud) + operand-> - lval.sdword); - else - { - retval = nullptr; - break; - } - } - } - } - - return retval; - } - - void* rtlpx_lookup_function_table_override(void* exception_address, FUNCTION_TABLE_DATA* out_data) - { - ZeroMemory(out_data, sizeof(*out_data)); - - auto* retval = seh::rtlpx_lookup_function_table(exception_address, out_data); - - const auto address_num = DWORD64(exception_address); - if (address_num >= seh::override_start && address_num <= seh::override_end) - { - if (address_num != 0) - { - *out_data = seh::overridden_table; - retval = PVOID(seh::overridden_table.TableAddress); - } - } - - return retval; - } - - void* rtlpx_lookup_function_table_override_down_level(void* exception_address, const PDWORD64 image_base, - const PULONG length) - { - auto* retval = seh::rtlpx_lookup_function_table_down_level(exception_address, image_base, length); - - const auto address_num = DWORD64(exception_address); - if (address_num >= seh::override_start && address_num <= seh::override_end) - { - if (address_num != 0) - { - *image_base = seh::overridden_table.ImageBase; - *length = seh::overridden_table.Size; - - retval = PVOID(seh::overridden_table.TableAddress); - } - } - - return retval; - } - } - - void setup_handler(void* module_base, void* module_end, PRUNTIME_FUNCTION runtime_functions, const DWORD entryCount) - { - const utils::nt::library ntdll("ntdll.dll"); - - seh::override_start = DWORD64(module_base); - seh::override_end = DWORD64(module_end); - - seh::overridden_table.ImageBase = seh::override_start; - seh::overridden_table.TableAddress = DWORD64(runtime_functions); - seh::overridden_table.Size = entryCount * sizeof(RUNTIME_FUNCTION); - - if (IsWindows8Point1OrGreater()) - { - struct - { - DWORD64 field0; - DWORD imageSize; - DWORD fieldC; - DWORD64 field10; - } query_result = {0, 0, 0, 0}; - - ntdll.invoke_pascal("NtQueryVirtualMemory", GetCurrentProcess(), module_base, 6, &query_result, - sizeof(query_result), nullptr); - seh::overridden_table.ImageSize = query_result.imageSize; - } - - auto* base_address = ntdll.get_proc("RtlLookupFunctionTable"); - auto* internal_address = find_call_from_address(base_address); - - void* patch_function = rtlpx_lookup_function_table_override; - auto** patch_original = reinterpret_cast(&seh::rtlpx_lookup_function_table); - - if (!internal_address) - { - if (!IsWindows8Point1OrGreater()) - { - internal_address = find_call_from_address(base_address, UD_Ijmp); - patch_function = rtlpx_lookup_function_table_override_down_level; - patch_original = reinterpret_cast(&seh::rtlpx_lookup_function_table_down_level); - } - - if (!internal_address) - { - if (IsWindows8OrGreater()) - { - // TODO: Catch the error - } - - internal_address = base_address; - patch_function = rtlpx_lookup_function_table_override_down_level; - patch_original = reinterpret_cast(&seh::rtlpx_lookup_function_table_down_level); - } - } - - static utils::hook::detour hook{}; - hook = utils::hook::detour(internal_address, patch_function); - *patch_original = hook.get_original(); - } -} diff --git a/src/client/loader/seh.hpp b/src/client/loader/seh.hpp deleted file mode 100644 index 232dca47..00000000 --- a/src/client/loader/seh.hpp +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -struct FUNCTION_TABLE_DATA -{ - DWORD64 TableAddress; - DWORD64 ImageBase; - DWORD ImageSize; // field +8 in ZwQueryVirtualMemory class 6 - DWORD Size; -}; - -typedef enum _FUNCTION_TABLE_TYPE -{ - RF_SORTED, - RF_UNSORTED, - RF_CALLBACK -} FUNCTION_TABLE_TYPE; - -typedef struct _DYNAMIC_FUNCTION_TABLE -{ - LIST_ENTRY Links; - PRUNTIME_FUNCTION FunctionTable; - LARGE_INTEGER TimeStamp; - - ULONG_PTR MinimumAddress; - ULONG_PTR MaximumAddress; - ULONG_PTR BaseAddress; - - PGET_RUNTIME_FUNCTION_CALLBACK Callback; - PVOID Context; - PWSTR OutOfProcessCallbackDll; - FUNCTION_TABLE_TYPE Type; - ULONG EntryCount; -} DYNAMIC_FUNCTION_TABLE, *PDYNAMIC_FUNCTION_TABLE; - -namespace seh -{ - void setup_handler(void* module_base, void* module_end, PRUNTIME_FUNCTION runtime_functions, DWORD entryCount); -} diff --git a/src/client/loader/tls.cpp b/src/client/loader/tls.cpp index 16b84806..b781907b 100644 --- a/src/client/loader/tls.cpp +++ b/src/client/loader/tls.cpp @@ -28,7 +28,13 @@ namespace tls throw std::runtime_error("Failed to load TLS DLL"); } - return reinterpret_cast(tls_dll.get_ptr() + tls_dll.get_optional_header() - ->DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS].VirtualAddress); + const auto tls_dir_entry = tls_dll.get_optional_header() + ->DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS].VirtualAddress; + if (!tls_dir_entry) + { + throw std::runtime_error("TLS DLL is invalid"); + } + + return reinterpret_cast(tls_dll.get_ptr() + tls_dir_entry); } } diff --git a/src/client/main.cpp b/src/client/main.cpp index c92fac07..6f0dc9a7 100644 --- a/src/client/main.cpp +++ b/src/client/main.cpp @@ -1,12 +1,13 @@ #include #include "loader/loader.hpp" +#include "launcher/launcher.hpp" #include "loader/component_loader.hpp" #include "game/game.hpp" -#include "component/console/console.hpp" - #include #include +#include +#include DECLSPEC_NORETURN void WINAPI exit_hook(const int code) { @@ -20,14 +21,74 @@ DWORD_PTR WINAPI set_thread_affinity_mask(HANDLE hThread, DWORD_PTR dwThreadAffi return SetThreadAffinityMask(hThread, dwThreadAffinityMask); } -FARPROC load_binary(uint64_t* base_address) +launcher::mode detect_mode_from_arguments() +{ + if (utils::flags::has_flag("dedicated")) + { + return launcher::mode::server; + } + + if (utils::flags::has_flag("multiplayer")) + { + return launcher::mode::multiplayer; + } + + if (utils::flags::has_flag("singleplayer")) + { + return launcher::mode::singleplayer; + } + + return launcher::mode::none; +} + +void apply_aslr_patch(std::string* data) +{ + // sp binary, mp binary + if (data->size() != 0xE46800 && data->size() != 0x12EFA00) + { + printf("%llu", data->size()); + throw std::runtime_error("File size mismatch, bad game files"); + } + + auto* dos_header = reinterpret_cast(&data->at(0)); + auto* nt_headers = reinterpret_cast(&data->at(dos_header->e_lfanew)); + auto* optional_header = &nt_headers->OptionalHeader; + + if (optional_header->DllCharacteristics & IMAGE_DLLCHARACTERISTICS_DYNAMIC_BASE) + { + optional_header->DllCharacteristics &= ~(IMAGE_DLLCHARACTERISTICS_DYNAMIC_BASE); + } +} + +void get_aslr_patched_binary(std::string* binary, std::string* data) +{ + const auto patched_binary = (utils::properties::get_appdata_path() / "bin" / *binary).generic_string(); + + try + { + apply_aslr_patch(data); + if (!utils::io::file_exists(patched_binary) && !utils::io::write_file(patched_binary, *data, false)) + { + throw std::runtime_error("Could not write file"); + } + } + catch (const std::exception& e) + { + throw std::runtime_error( + utils::string::va("Could not create aslr patched binary for %s! %s", + binary->data(), e.what()) + ); + } + + *binary = patched_binary; +} + +FARPROC load_binary(const launcher::mode mode) { - loader loader; utils::nt::library self; - loader.set_import_resolver([self](const std::string& library, const std::string& function) -> void* + loader::set_import_resolver([self](const std::string& library, const std::string& function) -> void* { - if (function == "ExitProcess") { return exit_hook; @@ -40,22 +101,37 @@ FARPROC load_binary(uint64_t* base_address) return component_loader::load_import(library, function); }); - std::string binary = "s2_mp64_ship.exe"; + std::string binary; + switch (mode) + { + case launcher::mode::server: + case launcher::mode::multiplayer: + binary = "s2_mp64_ship.exe"; + break; + case launcher::mode::singleplayer: + binary = "s2_sp64_ship.exe"; + break; + case launcher::mode::none: + default: + throw std::runtime_error("Invalid game mode!"); + } std::string data; if (!utils::io::read_file(binary, &data)) { throw std::runtime_error(utils::string::va( - "Failed to read game binary (%s)!\nPlease copy the iw7-mod.exe into your Call of Duty: WWII installation folder and run it from there.", + "Failed to read game binary (%s)!\nPlease copy the s2-mod.exe into your Call of Duty: WWII installation folder and run it from there.", binary.data())); } -#ifdef INJECT_HOST_AS_LIB - return loader.load_library(binary, base_address); -#else - *base_address = 0x140000000; - return loader.load(self, data); // not working -#endif + get_aslr_patched_binary(&binary, &data); + + const auto proc = loader::load_binary(binary); + auto* const peb = reinterpret_cast(__readgsqword(0x60)); + peb->Reserved3[1] = proc.get_ptr(); + static_assert(offsetof(PEB, Reserved3[1]) == 0x10); + + return FARPROC(proc.get_ptr() + proc.get_relative_entry_point()); } void remove_crash_file() @@ -109,8 +185,7 @@ int main() FARPROC entry_point; enable_dpi_awareness(); - // This requires admin privilege, but I suppose many - // people will start with admin rights if it crashes. + // This requires admin privilege limit_parallel_dll_loading(); srand(uint32_t(time(nullptr))); @@ -132,21 +207,20 @@ int main() { if (!component_loader::post_start()) return EXIT_FAILURE; - uint64_t base_address{}; - entry_point = load_binary(&base_address); + auto mode = detect_mode_from_arguments(); + if (mode == launcher::mode::none) + { + const launcher launcher; + mode = launcher.run(); + if (mode == launcher::mode::none) return 0; + } + + entry_point = load_binary(mode); if (!entry_point) { throw std::runtime_error("Unable to load binary into memory"); } - if (base_address != 0x140000000) - { - throw std::runtime_error(utils::string::va( - "Base address was (%p) and not (%p)\nThis should not be possible!", - base_address, 0x140000000)); - } - game::base_address = base_address; - if (!component_loader::post_load()) return EXIT_FAILURE; premature_shutdown = false; diff --git a/src/common/utils/io.cpp b/src/common/utils/io.cpp index 2e38c6e5..66d64578 100644 --- a/src/common/utils/io.cpp +++ b/src/common/utils/io.cpp @@ -63,8 +63,8 @@ namespace utils::io if (size > -1) { - data->resize(static_cast(size)); - stream.read(const_cast(data->data()), size); + data->resize(static_cast(size)); + stream.read(data->data(), size); stream.close(); return true; } diff --git a/src/common/utils/nt.cpp b/src/common/utils/nt.cpp index fdb643d9..5be76334 100644 --- a/src/common/utils/nt.cpp +++ b/src/common/utils/nt.cpp @@ -12,7 +12,7 @@ namespace utils::nt return library::load(path.generic_string()); } - library library::get_by_address(const void* address) + library library::get_by_address(void* address) { HMODULE handle = nullptr; GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, @@ -72,6 +72,11 @@ namespace utils::nt std::vector headers; auto nt_headers = this->get_nt_headers(); + if (!nt_headers) + { + return headers; + } + auto section = IMAGE_FIRST_SECTION(nt_headers); for (uint16_t i = 0; i < nt_headers->FileHeader.NumberOfSections; ++i, ++section) @@ -118,29 +123,28 @@ namespace utils::nt { if (!this->is_valid()) return ""; - auto path = this->get_path(); - const auto pos = path.find_last_of("/\\"); - if (pos == std::string::npos) return path; + const auto path = this->get_path(); + const auto pos = path.generic_string().find_last_of("/\\"); + if (pos == std::string::npos) return path.generic_string(); - return path.substr(pos + 1); + return path.generic_string().substr(pos + 1); } - std::string library::get_path() const + std::filesystem::path library::get_path() const { - if (!this->is_valid()) return ""; + if (!this->is_valid()) return {}; - char name[MAX_PATH] = { 0 }; - GetModuleFileNameA(this->module_, name, sizeof name); + wchar_t name[MAX_PATH] = { 0 }; + GetModuleFileNameW(this->module_, name, MAX_PATH); - return name; + return { name }; } - std::string library::get_folder() const + std::filesystem::path library::get_folder() const { - if (!this->is_valid()) return ""; + if (!this->is_valid()) return {}; - const auto path = std::filesystem::path(this->get_path()); - return path.parent_path().generic_string(); + return this->get_path().parent_path().generic_string(); } void library::free() @@ -157,7 +161,12 @@ namespace utils::nt return this->module_; } - void** library::get_iat_entry(const std::string& module_name, const std::string& proc_name) const + void** library::get_iat_entry(const std::string& module_name, std::string proc_name) const + { + return this->get_iat_entry(module_name, proc_name.data()); + } + + void** library::get_iat_entry(const std::string& module_name, const char* proc_name) const { if (!this->is_valid()) return nullptr; @@ -167,7 +176,7 @@ namespace utils::nt auto* const target_function = other_module.get_proc(proc_name); if (!target_function) return nullptr; - auto* header = this->get_optional_header(); + const auto* header = this->get_optional_header(); if (!header) return nullptr; auto* import_descriptor = reinterpret_cast(this->get_ptr() + header->DataDirectory @@ -184,7 +193,7 @@ namespace utils::nt while (original_thunk_data->u1.AddressOfData) { - if (thunk_data->u1.Function == (uint64_t)target_function) + if (thunk_data->u1.Function == reinterpret_cast(target_function)) { return reinterpret_cast(&thunk_data->u1.Function); } @@ -193,8 +202,8 @@ namespace utils::nt if (ordinal_number <= 0xFFFF) { - if (GetProcAddress(other_module.module_, reinterpret_cast(ordinal_number)) == - target_function) + auto* proc = GetProcAddress(other_module.module_, reinterpret_cast(ordinal_number)); + if (reinterpret_cast(proc) == target_function) { return reinterpret_cast(&thunk_data->u1.Function); } @@ -216,25 +225,14 @@ namespace utils::nt bool is_wine() { static const auto has_wine_export = []() -> bool - { - const library ntdll("ntdll.dll"); - return ntdll.get_proc("wine_get_version"); - }(); + { + const library ntdll("ntdll.dll"); + return ntdll.get_proc("wine_get_version"); + }(); return has_wine_export; } - bool is_shutdown_in_progress() - { - static auto* shutdown_in_progress = [] - { - const library ntdll("ntdll.dll"); - return ntdll.get_proc("RtlDllShutdownInProgress"); - }(); - - return shutdown_in_progress(); - } - void raise_hard_exception() { int data = false; @@ -254,7 +252,7 @@ namespace utils::nt return std::string(LPSTR(LockResource(handle)), SizeofResource(nullptr, res)); } - void relaunch_self() + void relaunch_self(const std::string& extra_command_line, bool override_command_line) { const utils::nt::library self; @@ -267,9 +265,21 @@ namespace utils::nt char current_dir[MAX_PATH]; GetCurrentDirectoryA(sizeof(current_dir), current_dir); - auto* const command_line = GetCommandLineA(); - CreateProcessA(self.get_path().data(), command_line, nullptr, nullptr, false, + std::string command_line = GetCommandLineA(); + if (!extra_command_line.empty()) + { + if (override_command_line) + { + command_line = extra_command_line; + } + else + { + command_line += " " + extra_command_line; + } + } + + CreateProcessA(self.get_path().generic_string().data(), command_line.data(), nullptr, nullptr, false, CREATE_NEW_CONSOLE, nullptr, current_dir, &startup_info, &process_info); if (process_info.hThread && process_info.hThread != INVALID_HANDLE_VALUE) CloseHandle(process_info.hThread); @@ -279,5 +289,6 @@ namespace utils::nt void terminate(const uint32_t code) { TerminateProcess(GetCurrentProcess(), code); + _Exit(code); } -} +} \ No newline at end of file diff --git a/src/common/utils/nt.hpp b/src/common/utils/nt.hpp index 14c7397b..d2675965 100644 --- a/src/common/utils/nt.hpp +++ b/src/common/utils/nt.hpp @@ -23,7 +23,7 @@ namespace utils::nt public: static library load(const std::string& name); static library load(const std::filesystem::path& path); - static library get_by_address(const void* address); + static library get_by_address(void* address); library(); explicit library(const std::string& name); @@ -40,23 +40,29 @@ namespace utils::nt operator HMODULE() const; void unprotect() const; - void* get_entry_point() const; - size_t get_relative_entry_point() const; + [[nodiscard]] void* get_entry_point() const; + [[nodiscard]] size_t get_relative_entry_point() const; - bool is_valid() const; - std::string get_name() const; - std::string get_path() const; - std::string get_folder() const; - std::uint8_t* get_ptr() const; + [[nodiscard]] bool is_valid() const; + [[nodiscard]] std::string get_name() const; + [[nodiscard]] std::filesystem::path get_path() const; + [[nodiscard]] std::filesystem::path get_folder() const; + [[nodiscard]] std::uint8_t* get_ptr() const; void free(); - HMODULE get_handle() const; + [[nodiscard]] HMODULE get_handle() const; template - T get_proc(const std::string& process) const + [[nodiscard]] T get_proc(const char* process) const { if (!this->is_valid()) T{}; - return reinterpret_cast(GetProcAddress(this->module_, process.data())); + return reinterpret_cast(GetProcAddress(this->module_, process)); + } + + template + [[nodiscard]] T get_proc(const std::string& process) const + { + return get_proc(process.data()); } template @@ -90,88 +96,24 @@ namespace utils::nt return T(); } - std::vector get_section_headers() const; + [[nodiscard]] std::vector get_section_headers() const; - PIMAGE_NT_HEADERS get_nt_headers() const; - PIMAGE_DOS_HEADER get_dos_header() const; - PIMAGE_OPTIONAL_HEADER get_optional_header() const; + [[nodiscard]] PIMAGE_NT_HEADERS get_nt_headers() const; + [[nodiscard]] PIMAGE_DOS_HEADER get_dos_header() const; + [[nodiscard]] PIMAGE_OPTIONAL_HEADER get_optional_header() const; - void** get_iat_entry(const std::string& module_name, const std::string& proc_name) const; + [[nodiscard]] void** get_iat_entry(const std::string& module_name, std::string proc_name) const; + [[nodiscard]] void** get_iat_entry(const std::string& module_name, const char* proc_name) const; private: HMODULE module_; }; - template - class handle - { - public: - handle() = default; - - handle(const HANDLE h) - : handle_(h) - { - } - - ~handle() - { - if (*this) - { - CloseHandle(this->handle_); - this->handle_ = InvalidHandle; - } - } - - handle(const handle&) = delete; - handle& operator=(const handle&) = delete; - - handle(handle&& obj) noexcept - : handle() - { - this->operator=(std::move(obj)); - } - - handle& operator=(handle&& obj) noexcept - { - if (this != &obj) - { - this->~handle(); - this->handle_ = obj.handle_; - obj.handle_ = InvalidHandle; - } - - return *this; - } - - handle& operator=(HANDLE h) noexcept - { - this->~handle(); - this->handle_ = h; - - return *this; - } - - operator bool() const - { - return this->handle_ != InvalidHandle; - } - - operator HANDLE() const - { - return this->handle_; - } - - private: - HANDLE handle_{ InvalidHandle }; - }; - bool is_wine(); - bool is_shutdown_in_progress(); - __declspec(noreturn) void raise_hard_exception(); std::string load_resource(int id); - void relaunch_self(); + void relaunch_self(const std::string& extra_command_line = "", bool override_command_line = false); __declspec(noreturn) void terminate(uint32_t code = 0); -} +} \ No newline at end of file