diff --git a/src/game/game.hpp b/src/game/game.hpp index 83585d1..4559c75 100644 --- a/src/game/game.hpp +++ b/src/game/game.hpp @@ -17,7 +17,7 @@ namespace game typedef void (*Cmd_RemoveCommand_t)(const char* cmdName); extern Cmd_RemoveCommand_t Cmd_RemoveCommand; - typedef void (*Com_Error_t)(int code, const char* fmt, ...); + typedef void (*Com_Error_t)(errorParm_t code, const char* fmt, ...); extern Com_Error_t Com_Error; typedef void (*DB_LoadXAssets_t)(XZoneInfo* zoneInfo, unsigned int zoneCount, int sync); diff --git a/src/game/scripting/event_handler.cpp b/src/game/scripting/event_handler.cpp index 16454a2..322bc20 100644 --- a/src/game/scripting/event_handler.cpp +++ b/src/game/scripting/event_handler.cpp @@ -1,4 +1,5 @@ #include + #include "context.hpp" namespace game::scripting @@ -13,10 +14,10 @@ namespace game::scripting "_event_listener_handle"); chai->add(chaiscript::fun( - [](event_listener_handle& lhs, const event_listener_handle& rhs) -> event_listener_handle& - { - return lhs = rhs; - }), "="); + [](event_listener_handle& lhs, const event_listener_handle& rhs) -> event_listener_handle& + { + return lhs = rhs; + }), "="); chai->add(chaiscript::fun([this](const event_listener_handle& handle) { @@ -45,77 +46,112 @@ namespace game::scripting } void event_handler::dispatch_to_specific_listeners(event* event, - const std::vector& arguments) + const std::vector& arguments) { - for (auto listener : this->event_listeners_) + this->event_listeners_.access([&](task_list& tasks) { - if (listener->event == event->name && listener->entity_id == event->entity_id) + for (auto listener = tasks.begin(); listener != tasks.end();) { - if (listener->is_volatile) + if (listener->event == event->name && listener->entity_id == event->entity_id) { - this->event_listeners_.remove(listener); + if (listener->is_volatile) + { + listener = tasks.erase(listener); + continue; + } + + listener->callback(arguments); } - listener->callback(arguments); + ++listener; } - } + }); } void event_handler::dispatch_to_generic_listeners(event* event, - const std::vector& arguments) + const std::vector& arguments) { - for (auto listener : this->generic_event_listeners_) + this->generic_event_listeners_.access([&](generic_task_list& tasks) { - if (listener->event == event->name) + for (auto listener = tasks.begin(); listener != tasks.end();) { - if (listener->is_volatile) + if (listener->event == event->name) { - this->generic_event_listeners_.remove(listener); + if (listener->is_volatile) + { + listener = tasks.erase(listener); + continue; + } + + listener->callback(entity(this->context_, event->entity_id), arguments); } - listener->callback(entity(this->context_, event->entity_id), arguments); + ++listener; } - } + }); } event_listener_handle event_handler::add_event_listener(event_listener listener) { listener.id = ++this->current_listener_id_; - this->event_listeners_.add(listener); + this->event_listeners_.access([listener](task_list& tasks) + { + tasks.emplace_back(std::move(listener)); + }); return {listener.id}; } event_listener_handle event_handler::add_event_listener(generic_event_listener listener) { listener.id = ++this->current_listener_id_; - this->generic_event_listeners_.add(listener); + this->generic_event_listeners_.access([listener](generic_task_list& tasks) + { + tasks.emplace_back(std::move(listener)); + }); return {listener.id}; } void event_handler::clear() { - this->event_listeners_.clear(); - this->generic_event_listeners_.clear(); + this->event_listeners_.access([](task_list& tasks) + { + tasks.clear(); + }); + + this->generic_event_listeners_.access([](generic_task_list& tasks) + { + tasks.clear(); + }); } void event_handler::remove(const event_listener_handle& handle) { - for (const auto task : this->event_listeners_) + this->event_listeners_.access([handle](task_list& tasks) { - if (task->id == handle.id) + for (auto i = tasks.begin(); i != tasks.end();) { - this->event_listeners_.remove(task); - return; - } - } + if (i->id == handle.id) + { + i = tasks.erase(i); + return; + } - for (const auto task : this->generic_event_listeners_) - { - if (task->id == handle.id) - { - this->generic_event_listeners_.remove(task); - return; + ++i; } - } + }); + + this->generic_event_listeners_.access([handle](generic_task_list& tasks) + { + for (auto i = tasks.begin(); i != tasks.end();) + { + if (i->id == handle.id) + { + i = tasks.erase(i); + return; + } + + ++i; + } + }); } } diff --git a/src/game/scripting/event_handler.hpp b/src/game/scripting/event_handler.hpp index 63d3c6f..aab1317 100644 --- a/src/game/scripting/event_handler.hpp +++ b/src/game/scripting/event_handler.hpp @@ -1,5 +1,5 @@ #pragma once -#include "utils/concurrent_list.hpp" +#include #include "entity.hpp" #include "event.hpp" @@ -46,8 +46,11 @@ namespace game::scripting context* context_; std::atomic_int64_t current_listener_id_ = 0; - utils::concurrent_list event_listeners_; - utils::concurrent_list generic_event_listeners_; + using task_list = std::vector; + utils::concurrency::container event_listeners_; + + using generic_task_list = std::vector; + utils::concurrency::container generic_event_listeners_; void dispatch_to_specific_listeners(event* event, const std::vector& arguments); void dispatch_to_generic_listeners(event* event, const std::vector& arguments); diff --git a/src/game/scripting/scheduler.cpp b/src/game/scripting/scheduler.cpp index 2c601d9..1a4c169 100644 --- a/src/game/scripting/scheduler.cpp +++ b/src/game/scripting/scheduler.cpp @@ -1,4 +1,5 @@ #include + #include "context.hpp" namespace game::scripting @@ -17,16 +18,16 @@ namespace game::scripting }), "="); chai->add(chaiscript::fun( - [this](const std::function& callback, const long long milliseconds) -> task_handle - { - return this->add(callback, milliseconds, true); - }), "setTimeout"); + [this](const std::function& callback, const long long milliseconds) -> task_handle + { + return this->add(callback, milliseconds, true); + }), "setTimeout"); chai->add(chaiscript::fun( - [this](const std::function& callback, const long long milliseconds) -> task_handle - { - return this->add(callback, milliseconds, false); - }), "setInterval"); + [this](const std::function& callback, const long long milliseconds) -> task_handle + { + return this->add(callback, milliseconds, false); + }), "setInterval"); const auto clear = [this](const task_handle& handle) { @@ -40,25 +41,40 @@ namespace game::scripting void scheduler::run_frame() { - for (auto task : this->tasks_) + this->tasks_.access([&](task_list& tasks) { - const auto now = std::chrono::steady_clock::now(); - if ((now - task->last_execution) > task->delay) + for (auto i = tasks.begin(); i != tasks.end();) { - task->last_execution = now; - if (task->is_volatile) + const auto now = std::chrono::steady_clock::now(); + const auto diff = now - i->last_execution; + + if (diff < i->delay) { - this->tasks_.remove(task); + ++i; + continue; } - task->callback(); + i->last_execution = now; + + if (i->is_volatile) + { + i = tasks.erase(i); + } + else + { + i->callback(); + ++i; + } } - } + }); } void scheduler::clear() { - this->tasks_.clear(); + this->tasks_.access([&](task_list& tasks) + { + tasks.clear(); + }); } task_handle scheduler::add(const std::function& callback, const long long milliseconds, @@ -77,20 +93,28 @@ namespace game::scripting task.last_execution = std::chrono::steady_clock::now(); task.id = ++this->current_task_id_; - this->tasks_.add(task); + this->tasks_.access([&task](task_list& tasks) + { + tasks.emplace_back(std::move(task)); + }); return {task.id}; } void scheduler::remove(const task_handle& handle) { - for (auto task : this->tasks_) + this->tasks_.access([&](task_list& tasks) { - if (task->id == handle.id) + for (auto i = tasks.begin(); i != tasks.end();) { - this->tasks_.remove(task); - break; + if (i->id == handle.id) + { + i = tasks.erase(i); + break; + } + + ++i; } - } + }); } } diff --git a/src/game/scripting/scheduler.hpp b/src/game/scripting/scheduler.hpp index e92348e..3783b66 100644 --- a/src/game/scripting/scheduler.hpp +++ b/src/game/scripting/scheduler.hpp @@ -1,5 +1,5 @@ #pragma once -#include "utils/concurrent_list.hpp" +#include namespace game::scripting { @@ -8,7 +8,7 @@ namespace game::scripting class task_handle { public: - unsigned long long id = 0; + std::uint64_t id = 0; }; class task final : public task_handle @@ -31,7 +31,8 @@ namespace game::scripting private: context* context_; - utils::concurrent_list tasks_; + using task_list = std::vector; + utils::concurrency::container tasks_; std::atomic_int64_t current_task_id_ = 0; task_handle add(const std::function& callback, long long milliseconds, bool is_volatile); diff --git a/src/game/structs.hpp b/src/game/structs.hpp index d72ca9d..f7b0125 100644 --- a/src/game/structs.hpp +++ b/src/game/structs.hpp @@ -374,6 +374,18 @@ namespace game }; #pragma pack(pop) + enum errorParm_t + { + ERR_FATAL = 0x0, + ERR_DROP = 0x1, + ERR_SERVERDISCONNECT = 0x2, + ERR_DISCONNECT = 0x3, + ERR_SCRIPT = 0x4, + ERR_SCRIPT_DROP = 0x5, + ERR_LOCALIZATION = 0x6, + ERR_MAPLOADERRORSUMMARY = 0x7 + }; + enum svscmd_type { SV_CMD_CAN_IGNORE, diff --git a/src/module/console.cpp b/src/module/console.cpp index 920aa54..eccb500 100644 --- a/src/module/console.cpp +++ b/src/module/console.cpp @@ -21,7 +21,8 @@ public: void post_start() override { - scheduler::on_frame(std::bind(&console::log_messages, this)); + scheduler::loop(std::bind(&console::log_messages, this), scheduler::pipeline::main); + this->console_runner_ = std::thread(std::bind(&console::runner, this)); } diff --git a/src/module/discord.cpp b/src/module/discord.cpp index e76a8a8..1cbdf68 100644 --- a/src/module/discord.cpp +++ b/src/module/discord.cpp @@ -24,7 +24,7 @@ public: Discord_Initialize("531526691319971880", &handlers, 1, nullptr); - scheduler::on_frame(Discord_RunCallbacks); + scheduler::loop(Discord_RunCallbacks, scheduler::pipeline::main); } void pre_destroy() override diff --git a/src/module/scheduler.cpp b/src/module/scheduler.cpp index e3b3eba..7c87997 100644 --- a/src/module/scheduler.cpp +++ b/src/module/scheduler.cpp @@ -1,103 +1,181 @@ #include #include -#include + #include +#include +#include #include "game/game.hpp" #include "scheduler.hpp" -std::mutex scheduler::mutex_; -std::queue> scheduler::errors_; -utils::concurrent_list> scheduler::callbacks_; -utils::concurrent_list> scheduler::single_callbacks_; - -void scheduler::on_frame(const std::function& callback) +namespace { - std::lock_guard _(mutex_); - callbacks_.add(callback); + constexpr bool cond_continue = false; + constexpr bool cond_end = true; + + struct task + { + std::function handler{}; + std::chrono::milliseconds interval{}; + std::chrono::high_resolution_clock::time_point last_call{}; + }; + + using task_list = std::vector; + + class task_pipeline + { + public: + void add(task&& task) + { + new_callbacks_.access([&task](task_list& tasks) + { + tasks.emplace_back(std::move(task)); + }); + } + + void execute() + { + callbacks_.access([&](task_list& tasks) + { + this->merge_callbacks(); + + for (auto i = tasks.begin(); i != tasks.end();) + { + const auto now = std::chrono::high_resolution_clock::now(); + const auto diff = now - i->last_call; + + if (diff < i->interval) + { + ++i; + continue; + } + + i->last_call = now; + + const auto res = i->handler(); + if (res == cond_end) + { + i = tasks.erase(i); + } + else + { + ++i; + } + } + }); + } + + private: + utils::concurrency::container new_callbacks_; + utils::concurrency::container callbacks_; + + void merge_callbacks() + { + callbacks_.access([&](task_list& tasks) + { + new_callbacks_.access([&](task_list& new_tasks) + { + tasks.insert(tasks.end(), std::move_iterator(new_tasks.begin()), std::move_iterator(new_tasks.end())); + new_tasks = {}; + }); + }); + } + }; + + volatile bool kill = false; + std::thread thread; + task_pipeline pipelines[scheduler::pipeline::count]; } -void scheduler::once(const std::function& callback) +void scheduler::execute(const pipeline type) { - std::lock_guard _(mutex_); - single_callbacks_.add(callback); + assert(type >= 0 && type < pipeline::count); + pipelines[type].execute(); } -void scheduler::error(const std::string& message, int level) +void scheduler::r_end_frame_stub() { - std::lock_guard _(mutex_); - errors_.emplace(message, level); + reinterpret_cast(SELECT_VALUE(0x4193D0, 0x67F840, 0x0))(); + execute(pipeline::renderer); } -void scheduler::frame_stub() +void scheduler::g_glass_update_stub() +{ + reinterpret_cast(SELECT_VALUE(0x4E3730, 0x505BB0, 0x481EA0))(); + execute(pipeline::server); +} + +void scheduler::main_frame_stub() { - execute(); reinterpret_cast(SELECT_VALUE(0x458600, 0x556470, 0x4DB070))(); + execute(pipeline::main); } -__declspec(naked) void scheduler::execute() +void scheduler::schedule(const std::function& callback, const pipeline type, + const std::chrono::milliseconds delay) { - __asm - { - call execute_error - call execute_safe - retn - } + assert(type >= 0 && type < pipeline::count); + + task task; + task.handler = callback; + task.interval = delay; + task.last_call = std::chrono::high_resolution_clock::now(); + + pipelines[type].add(std::move(task)); } -void scheduler::execute_safe() +void scheduler::loop(const std::function& callback, const pipeline type, + const std::chrono::milliseconds delay) { - for (auto callback : callbacks_) + schedule([callback]() { - (*callback)(); - } - - for (auto callback : single_callbacks_) - { - single_callbacks_.remove(callback); - (*callback)(); - } + callback(); + return cond_continue; + }, type, delay); } -void scheduler::execute_error() +void scheduler::once(const std::function& callback, const pipeline type, + const std::chrono::milliseconds delay) { - const char* message = nullptr; - int level = 0; - - if (get_next_error(&message, &level) && message) + schedule([callback]() { - game::native::Com_Error(level, "%s", message); - } + callback(); + return cond_end; + }, type, delay); } -bool scheduler::get_next_error(const char** error_message, int* error_level) +void scheduler::post_start() { - std::lock_guard _(mutex_); - if (errors_.empty()) + thread = utils::thread::create_named_thread("Async Scheduler", []() { - *error_message = nullptr; - return false; - } - - const auto error = errors_.front(); - errors_.pop(); - - *error_level = error.second; - *error_message = utils::string::va("%s", error.first.data()); - - return true; + while (!kill) + { + execute(pipeline::async); + std::this_thread::sleep_for(10ms); + } + }); } void scheduler::post_load() { - utils::hook(SELECT_VALUE(0x44C7DB, 0x55688E, 0x4DB324), frame_stub, HOOK_CALL).install()->quick(); + utils::hook(SELECT_VALUE(0x44C7DB, 0x55688E, 0x4DB324), main_frame_stub, HOOK_CALL).install()->quick(); + + if (!game::is_dedi()) + { + utils::hook(SELECT_VALUE(0x57F7F8, 0x4978E2, 0x0), r_end_frame_stub, HOOK_CALL).install()->quick(); + } + + // Hook a function inside G_RunFrame. Fixes TLS issues + utils::hook(SELECT_VALUE(0x52EFBC, 0x50CEC6, 0x48B277), g_glass_update_stub, HOOK_CALL).install()->quick(); } void scheduler::pre_destroy() { - std::lock_guard _(mutex_); - callbacks_.clear(); - single_callbacks_.clear(); + kill = true; + if (thread.joinable()) + { + thread.join(); + } } REGISTER_MODULE(scheduler); diff --git a/src/module/scheduler.hpp b/src/module/scheduler.hpp index a249c90..c1e08d3 100644 --- a/src/module/scheduler.hpp +++ b/src/module/scheduler.hpp @@ -1,27 +1,40 @@ #pragma once -#include "utils/concurrent_list.hpp" class scheduler final : public module { public: - static void on_frame(const std::function& callback); - static void once(const std::function& callback); + enum pipeline + { + // Asynchronuous pipeline, disconnected from the game + async = 0, - static void error(const std::string& message, int level); + // The game's rendering pipeline + renderer, + // The game's server thread + server, + + // The game's main thread + main, + + count, + }; + + void post_start() override; void post_load() override; void pre_destroy() override; + static void schedule(const std::function& callback, pipeline type = pipeline::async, + std::chrono::milliseconds delay = 0ms); + static void loop(const std::function& callback, pipeline type = pipeline::async, + std::chrono::milliseconds delay = 0ms); + static void once(const std::function& callback, pipeline type = pipeline::async, + std::chrono::milliseconds delay = 0ms); + private: - static std::mutex mutex_; - static std::queue> errors_; - static utils::concurrent_list> callbacks_; - static utils::concurrent_list> single_callbacks_; + static void execute(const pipeline type); - static void frame_stub(); - - static void execute(); - static void execute_safe(); - static void execute_error(); - static bool get_next_error(const char** error_message, int* error_level); + static void r_end_frame_stub(); + static void g_glass_update_stub(); + static void main_frame_stub(); }; diff --git a/src/module/scripting.cpp b/src/module/scripting.cpp index 8d3b263..b3a8420 100644 --- a/src/module/scripting.cpp +++ b/src/module/scripting.cpp @@ -48,7 +48,7 @@ private: { const auto script_dir = "open-iw5/scripts/"s; - if(!utils::io::directory_exists(script_dir)) + if (!utils::io::directory_exists(script_dir)) { return; } @@ -115,7 +115,10 @@ private: printf("%s\n", e.what()); printf("**************************************\n\n"); - scheduler::error("Script execution error\n(see console for actual details)\n", 5); + scheduler::once([] + { + game::native::Com_Error(game::native::errorParm_t::ERR_SCRIPT, "Script execution error\n(see console for actual details)\n"); + }, scheduler::pipeline::main); } static void start_execution_stub() diff --git a/src/module/test_clients.cpp b/src/module/test_clients.cpp index 092b07a..0e08ceb 100644 --- a/src/module/test_clients.cpp +++ b/src/module/test_clients.cpp @@ -1,10 +1,12 @@ #include #include #include +#include #include "game/game.hpp" #include "test_clients.hpp" #include "command.hpp" +#include "scheduler.hpp" bool test_clients::can_add() { @@ -91,7 +93,7 @@ game::native::gentity_s* test_clients::sv_add_test_client() return client->gentity; } -void test_clients::spawn() +void test_clients::gscr_add_test_client() { const auto* ent = test_clients::sv_add_test_client(); @@ -101,6 +103,35 @@ void test_clients::spawn() } } +void test_clients::spawn(const int count) +{ + for (int i = 0; i < count; ++i) + { + scheduler::once([]() + { + auto* ent = sv_add_test_client(); + if (ent == nullptr) return; + + game::native::Scr_AddEntityNum(ent->s.number, 0); + scheduler::once([ent]() + { + game::native::Scr_AddString("autoassign"); + game::native::Scr_AddString("team_marinesopfor"); + game::native::Scr_Notify(ent, static_cast(game::native::SL_GetString("menuresponse", 0)), 2); + + scheduler::once([ent]() + { + game::native::Scr_AddString(utils::string::va("class%i", std::rand() % 5)); + game::native::Scr_AddString("changeclass"); + game::native::Scr_Notify(ent, static_cast(game::native::SL_GetString("menuresponse", 0)), 2); + }, scheduler::pipeline::server, 2s); + + }, scheduler::pipeline::server, 1s); + + }, scheduler::pipeline::server, 2s * (i + 1)); + } +} + void test_clients::scr_shutdown_system_mp_stub(unsigned char sys) { game::native::SV_DropAllBots(); @@ -158,11 +189,15 @@ void test_clients::post_load() if (game::is_mp()) this->patch_mp(); else return; // No sp/dedi bots for now :( - command::add("spawnBot", []() + command::add("spawnBot", [](const command::params& params) { - // Because I am unable to expand the scheduler at the moment - // we only get one bot at the time - test_clients::spawn(); + if (params.size() < 2) + { + return; + } + + const auto count = std::atoi(params.get(1)); + test_clients::spawn(count); }); } @@ -180,7 +215,7 @@ void test_clients::patch_mp() utils::hook(0x576DCC, &test_clients::check_timeouts_stub_mp, HOOK_JUMP).install()->quick(); // SV_CheckTimeouts // Replace nullsubbed gsc func "GScr_AddTestClient" with our spawn - utils::hook::set(0x8AC8DC, test_clients::spawn); + utils::hook::set(0x8AC8DC, test_clients::gscr_add_test_client); } REGISTER_MODULE(test_clients); diff --git a/src/module/test_clients.hpp b/src/module/test_clients.hpp index 81e1a40..6bdff24 100644 --- a/src/module/test_clients.hpp +++ b/src/module/test_clients.hpp @@ -10,7 +10,8 @@ private: static bool can_add(); static game::native::gentity_s* sv_add_test_client(); - static void spawn(); + static void gscr_add_test_client(); + static void spawn(int count); static void scr_shutdown_system_mp_stub(unsigned char sys); diff --git a/src/utils/concurrency.hpp b/src/utils/concurrency.hpp new file mode 100644 index 0000000..05c5d3a --- /dev/null +++ b/src/utils/concurrency.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include + +namespace utils::concurrency +{ + template + class container + { + public: + template + R access(F&& accessor) const + { + std::lock_guard _{mutex_}; + return accessor(object_); + } + + template + R access(F&& accessor) + { + std::lock_guard _{mutex_}; + return accessor(object_); + } + + template + R access_with_lock(F&& accessor) const + { + std::unique_lock lock{mutex_}; + return accessor(object_, lock); + } + + template + R access_with_lock(F&& accessor) + { + std::unique_lock lock{mutex_}; + return accessor(object_, lock); + } + + T& get_raw() { return object_; } + const T& get_raw() const { return object_; } + + private: + mutable MutexType mutex_{}; + T object_{}; + }; +} diff --git a/src/utils/concurrent_list.hpp b/src/utils/concurrent_list.hpp deleted file mode 100644 index 135f20f..0000000 --- a/src/utils/concurrent_list.hpp +++ /dev/null @@ -1,131 +0,0 @@ -#pragma once - -namespace utils -{ - template - class concurrent_list final - { - public: - class element final - { - public: - explicit element(std::recursive_mutex* mutex, std::shared_ptr entry = {}, - std::shared_ptr next = {}) : - mutex_(mutex), - entry_(std::move(entry)), - next_(std::move(next)) - { - } - - void remove(const std::shared_ptr& element) - { - std::lock_guard _(*this->mutex_); - if (!this->next_) return; - - if (this->next_->entry_.get() == element.get()) - { - this->next_ = this->next_->next_; - } - else - { - this->next_->remove(element); - } - } - - std::shared_ptr get_next() const - { - std::lock_guard _(*this->mutex_); - return this->next_; - } - - std::shared_ptr operator*() const - { - std::lock_guard _(*this->mutex_); - return this->entry_; - } - - element& operator++() - { - std::lock_guard _(*this->mutex_); - *this = this->next_ ? *this->next_ : element(this->mutex_); - return *this; - } - - element operator++(int) - { - std::lock_guard _(*this->mutex_); - auto result = *this; - this->operator++(); - return result; - } - - bool operator==(const element& other) - { - std::lock_guard _(*this->mutex_); - return this->entry_.get() == other.entry_.get(); - } - - bool operator!=(const element& other) - { - std::lock_guard _(*this->mutex_); - return !(*this == other); - } - - private: - std::recursive_mutex* mutex_; - std::shared_ptr entry_; - std::shared_ptr next_; - }; - - element begin() - { - std::lock_guard _(this->mutex_); - return this->entry_ ? *this->entry_ : this->end(); - } - - element end() - { - std::lock_guard _(this->mutex_); - return element(&this->mutex_); - } - - void remove(const element& entry) - { - std::lock_guard _(this->mutex_); - this->remove(*entry); - } - - void remove(const std::shared_ptr& element) - { - std::lock_guard _(this->mutex_); - if (!this->entry_) return; - - if ((**this->entry_).get() == element.get()) - { - this->entry_ = this->entry_->get_next(); - } - else - { - this->entry_->remove(element); - } - } - - void add(const T& object) - { - std::lock_guard _(this->mutex_); - - const auto object_ptr = std::make_shared(object); - this->entry_ = std::make_shared(&this->mutex_, object_ptr, this->entry_); - } - - void clear() - { - std::lock_guard _(this->mutex_); - this->entry_ = {}; - } - - private: - std::recursive_mutex mutex_; - std::shared_ptr entry_; - }; -} diff --git a/src/utils/string.cpp b/src/utils/string.cpp index e22f52d..50abd4f 100644 --- a/src/utils/string.cpp +++ b/src/utils/string.cpp @@ -36,6 +36,19 @@ namespace utils::string return text; } + std::wstring convert(const std::string& str) + { + std::wstring result; + result.reserve(str.size()); + + for (const auto& chr : str) + { + result.push_back(static_cast(chr)); + } + + return result; + } + std::string dump_hex(const std::string& data, const std::string& separator) { std::string result; diff --git a/src/utils/string.hpp b/src/utils/string.hpp index d7776d4..5874629 100644 --- a/src/utils/string.hpp +++ b/src/utils/string.hpp @@ -77,5 +77,7 @@ namespace utils::string std::string to_lower(std::string text); std::string to_upper(std::string text); + std::wstring convert(const std::string& str); + std::string dump_hex(const std::string& data, const std::string& separator = " "); } diff --git a/src/utils/thread.cpp b/src/utils/thread.cpp new file mode 100644 index 0000000..5840507 --- /dev/null +++ b/src/utils/thread.cpp @@ -0,0 +1,130 @@ +#include + +#include "thread.hpp" +#include "string.hpp" +#include "nt.hpp" + +#include + +#include + +namespace utils::thread +{ + bool set_name(const HANDLE t, const std::string& name) + { + const nt::library kernel32("kernel32.dll"); + if (!kernel32) + { + return false; + } + + const auto set_description = kernel32.get_proc("SetThreadDescription"); + if (!set_description) + { + return false; + } + + return SUCCEEDED(set_description(t, string::convert(name).data())); + } + + bool set_name(const DWORD id, const std::string& name) + { + auto* const t = OpenThread(THREAD_SET_LIMITED_INFORMATION, FALSE, id); + if (!t) return false; + + const auto _ = gsl::finally([t]() + { + CloseHandle(t); + }); + + return set_name(t, name); + } + + bool set_name(std::thread& t, const std::string& name) + { + return set_name(t.native_handle(), name); + } + + bool set_name(const std::string& name) + { + return set_name(GetCurrentThread(), name); + } + + std::vector get_thread_ids() + { + auto* const h = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, GetCurrentProcessId()); + if (h == INVALID_HANDLE_VALUE) + { + return {}; + } + + const auto _ = gsl::finally([h]() + { + CloseHandle(h); + }); + + THREADENTRY32 entry{}; + entry.dwSize = sizeof(entry); + if (!Thread32First(h, &entry)) + { + return {}; + } + + std::vector ids{}; + + do + { + const auto check_size = entry.dwSize < FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) + + sizeof(entry.th32OwnerProcessID); + entry.dwSize = sizeof(entry); + + if (check_size && entry.th32OwnerProcessID == GetCurrentProcessId()) + { + ids.emplace_back(entry.th32ThreadID); + } + } while (Thread32Next(h, &entry)); + + return ids; + } + + void for_each_thread(const std::function& callback) + { + const auto ids = get_thread_ids(); + + for (const auto& id : ids) + { + auto* const thread = OpenThread(THREAD_ALL_ACCESS, FALSE, id); + if (thread != nullptr) + { + const auto _ = gsl::finally([thread]() + { + CloseHandle(thread); + }); + + callback(thread); + } + } + } + + void suspend_other_threads() + { + for_each_thread([](const HANDLE thread) + { + if (GetThreadId(thread) != GetCurrentThreadId()) + { + SuspendThread(thread); + } + }); + } + + void resume_other_threads() + { + for_each_thread([](const HANDLE thread) + { + if (GetThreadId(thread) != GetCurrentThreadId()) + { + ResumeThread(thread); + } + }); + } +} diff --git a/src/utils/thread.hpp b/src/utils/thread.hpp new file mode 100644 index 0000000..e7fcbd0 --- /dev/null +++ b/src/utils/thread.hpp @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace utils::thread +{ + bool set_name(HANDLE t, const std::string& name); + bool set_name(DWORD id, const std::string& name); + bool set_name(std::thread& t, const std::string& name); + bool set_name(const std::string& name); + + template + std::thread create_named_thread(const std::string& name, Args&&... args) + { + auto t = std::thread(std::forward(args)...); + set_name(t, name); + return t; + } + + std::vector get_thread_ids(); + void for_each_thread(const std::function& callback); + + void suspend_other_threads(); + void resume_other_threads(); +}