Merge pull request #77 from diamante0018/master

Scheduler refactoring [Fix spawnbot]
This commit is contained in:
Maurice Heumann 2022-04-25 17:16:49 +02:00 committed by GitHub
commit feb5cf67c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 571 additions and 280 deletions

View File

@ -17,7 +17,7 @@ namespace game
typedef void (*Cmd_RemoveCommand_t)(const char* cmdName);
extern Cmd_RemoveCommand_t Cmd_RemoveCommand;
typedef void (*Com_Error_t)(int code, const char* fmt, ...);
typedef void (*Com_Error_t)(errorParm_t code, const char* fmt, ...);
extern Com_Error_t Com_Error;
typedef void (*DB_LoadXAssets_t)(XZoneInfo* zoneInfo, unsigned int zoneCount, int sync);

View File

@ -1,4 +1,5 @@
#include <std_include.hpp>
#include "context.hpp"
namespace game::scripting
@ -13,10 +14,10 @@ namespace game::scripting
"_event_listener_handle");
chai->add(chaiscript::fun(
[](event_listener_handle& lhs, const event_listener_handle& rhs) -> event_listener_handle&
{
return lhs = rhs;
}), "=");
[](event_listener_handle& lhs, const event_listener_handle& rhs) -> event_listener_handle&
{
return lhs = rhs;
}), "=");
chai->add(chaiscript::fun([this](const event_listener_handle& handle)
{
@ -45,77 +46,112 @@ namespace game::scripting
}
void event_handler::dispatch_to_specific_listeners(event* event,
const std::vector<chaiscript::Boxed_Value>& arguments)
const std::vector<chaiscript::Boxed_Value>& arguments)
{
for (auto listener : this->event_listeners_)
this->event_listeners_.access([&](task_list& tasks)
{
if (listener->event == event->name && listener->entity_id == event->entity_id)
for (auto listener = tasks.begin(); listener != tasks.end();)
{
if (listener->is_volatile)
if (listener->event == event->name && listener->entity_id == event->entity_id)
{
this->event_listeners_.remove(listener);
if (listener->is_volatile)
{
listener = tasks.erase(listener);
continue;
}
listener->callback(arguments);
}
listener->callback(arguments);
++listener;
}
}
});
}
void event_handler::dispatch_to_generic_listeners(event* event,
const std::vector<chaiscript::Boxed_Value>& arguments)
const std::vector<chaiscript::Boxed_Value>& arguments)
{
for (auto listener : this->generic_event_listeners_)
this->generic_event_listeners_.access([&](generic_task_list& tasks)
{
if (listener->event == event->name)
for (auto listener = tasks.begin(); listener != tasks.end();)
{
if (listener->is_volatile)
if (listener->event == event->name)
{
this->generic_event_listeners_.remove(listener);
if (listener->is_volatile)
{
listener = tasks.erase(listener);
continue;
}
listener->callback(entity(this->context_, event->entity_id), arguments);
}
listener->callback(entity(this->context_, event->entity_id), arguments);
++listener;
}
}
});
}
event_listener_handle event_handler::add_event_listener(event_listener listener)
{
listener.id = ++this->current_listener_id_;
this->event_listeners_.add(listener);
this->event_listeners_.access([listener](task_list& tasks)
{
tasks.emplace_back(std::move(listener));
});
return {listener.id};
}
event_listener_handle event_handler::add_event_listener(generic_event_listener listener)
{
listener.id = ++this->current_listener_id_;
this->generic_event_listeners_.add(listener);
this->generic_event_listeners_.access([listener](generic_task_list& tasks)
{
tasks.emplace_back(std::move(listener));
});
return {listener.id};
}
void event_handler::clear()
{
this->event_listeners_.clear();
this->generic_event_listeners_.clear();
this->event_listeners_.access([](task_list& tasks)
{
tasks.clear();
});
this->generic_event_listeners_.access([](generic_task_list& tasks)
{
tasks.clear();
});
}
void event_handler::remove(const event_listener_handle& handle)
{
for (const auto task : this->event_listeners_)
this->event_listeners_.access([handle](task_list& tasks)
{
if (task->id == handle.id)
for (auto i = tasks.begin(); i != tasks.end();)
{
this->event_listeners_.remove(task);
return;
}
}
if (i->id == handle.id)
{
i = tasks.erase(i);
return;
}
for (const auto task : this->generic_event_listeners_)
{
if (task->id == handle.id)
{
this->generic_event_listeners_.remove(task);
return;
++i;
}
}
});
this->generic_event_listeners_.access([handle](generic_task_list& tasks)
{
for (auto i = tasks.begin(); i != tasks.end();)
{
if (i->id == handle.id)
{
i = tasks.erase(i);
return;
}
++i;
}
});
}
}

View File

@ -1,5 +1,5 @@
#pragma once
#include "utils/concurrent_list.hpp"
#include <utils/concurrency.hpp>
#include "entity.hpp"
#include "event.hpp"
@ -46,8 +46,11 @@ namespace game::scripting
context* context_;
std::atomic_int64_t current_listener_id_ = 0;
utils::concurrent_list<event_listener> event_listeners_;
utils::concurrent_list<generic_event_listener> generic_event_listeners_;
using task_list = std::vector<event_listener>;
utils::concurrency::container<task_list> event_listeners_;
using generic_task_list = std::vector<generic_event_listener>;
utils::concurrency::container<generic_task_list> generic_event_listeners_;
void dispatch_to_specific_listeners(event* event, const std::vector<chaiscript::Boxed_Value>& arguments);
void dispatch_to_generic_listeners(event* event, const std::vector<chaiscript::Boxed_Value>& arguments);

View File

@ -1,4 +1,5 @@
#include <std_include.hpp>
#include "context.hpp"
namespace game::scripting
@ -17,16 +18,16 @@ namespace game::scripting
}), "=");
chai->add(chaiscript::fun(
[this](const std::function<void()>& callback, const long long milliseconds) -> task_handle
{
return this->add(callback, milliseconds, true);
}), "setTimeout");
[this](const std::function<void()>& callback, const long long milliseconds) -> task_handle
{
return this->add(callback, milliseconds, true);
}), "setTimeout");
chai->add(chaiscript::fun(
[this](const std::function<void()>& callback, const long long milliseconds) -> task_handle
{
return this->add(callback, milliseconds, false);
}), "setInterval");
[this](const std::function<void()>& callback, const long long milliseconds) -> task_handle
{
return this->add(callback, milliseconds, false);
}), "setInterval");
const auto clear = [this](const task_handle& handle)
{
@ -40,25 +41,40 @@ namespace game::scripting
void scheduler::run_frame()
{
for (auto task : this->tasks_)
this->tasks_.access([&](task_list& tasks)
{
const auto now = std::chrono::steady_clock::now();
if ((now - task->last_execution) > task->delay)
for (auto i = tasks.begin(); i != tasks.end();)
{
task->last_execution = now;
if (task->is_volatile)
const auto now = std::chrono::steady_clock::now();
const auto diff = now - i->last_execution;
if (diff < i->delay)
{
this->tasks_.remove(task);
++i;
continue;
}
task->callback();
i->last_execution = now;
if (i->is_volatile)
{
i = tasks.erase(i);
}
else
{
i->callback();
++i;
}
}
}
});
}
void scheduler::clear()
{
this->tasks_.clear();
this->tasks_.access([&](task_list& tasks)
{
tasks.clear();
});
}
task_handle scheduler::add(const std::function<void()>& callback, const long long milliseconds,
@ -77,20 +93,28 @@ namespace game::scripting
task.last_execution = std::chrono::steady_clock::now();
task.id = ++this->current_task_id_;
this->tasks_.add(task);
this->tasks_.access([&task](task_list& tasks)
{
tasks.emplace_back(std::move(task));
});
return {task.id};
}
void scheduler::remove(const task_handle& handle)
{
for (auto task : this->tasks_)
this->tasks_.access([&](task_list& tasks)
{
if (task->id == handle.id)
for (auto i = tasks.begin(); i != tasks.end();)
{
this->tasks_.remove(task);
break;
if (i->id == handle.id)
{
i = tasks.erase(i);
break;
}
++i;
}
}
});
}
}

View File

@ -1,5 +1,5 @@
#pragma once
#include "utils/concurrent_list.hpp"
#include <utils/concurrency.hpp>
namespace game::scripting
{
@ -8,7 +8,7 @@ namespace game::scripting
class task_handle
{
public:
unsigned long long id = 0;
std::uint64_t id = 0;
};
class task final : public task_handle
@ -31,7 +31,8 @@ namespace game::scripting
private:
context* context_;
utils::concurrent_list<task> tasks_;
using task_list = std::vector<task>;
utils::concurrency::container<task_list> tasks_;
std::atomic_int64_t current_task_id_ = 0;
task_handle add(const std::function<void()>& callback, long long milliseconds, bool is_volatile);

View File

@ -374,6 +374,18 @@ namespace game
};
#pragma pack(pop)
enum errorParm_t
{
ERR_FATAL = 0x0,
ERR_DROP = 0x1,
ERR_SERVERDISCONNECT = 0x2,
ERR_DISCONNECT = 0x3,
ERR_SCRIPT = 0x4,
ERR_SCRIPT_DROP = 0x5,
ERR_LOCALIZATION = 0x6,
ERR_MAPLOADERRORSUMMARY = 0x7
};
enum svscmd_type
{
SV_CMD_CAN_IGNORE,

View File

@ -21,7 +21,8 @@ public:
void post_start() override
{
scheduler::on_frame(std::bind(&console::log_messages, this));
scheduler::loop(std::bind(&console::log_messages, this), scheduler::pipeline::main);
this->console_runner_ = std::thread(std::bind(&console::runner, this));
}

View File

@ -24,7 +24,7 @@ public:
Discord_Initialize("531526691319971880", &handlers, 1, nullptr);
scheduler::on_frame(Discord_RunCallbacks);
scheduler::loop(Discord_RunCallbacks, scheduler::pipeline::main);
}
void pre_destroy() override

View File

@ -1,103 +1,181 @@
#include <std_include.hpp>
#include <loader/module_loader.hpp>
#include <utils/string.hpp>
#include <utils/hook.hpp>
#include <utils/thread.hpp>
#include <utils/concurrency.hpp>
#include "game/game.hpp"
#include "scheduler.hpp"
std::mutex scheduler::mutex_;
std::queue<std::pair<std::string, int>> scheduler::errors_;
utils::concurrent_list<std::function<void()>> scheduler::callbacks_;
utils::concurrent_list<std::function<void()>> scheduler::single_callbacks_;
void scheduler::on_frame(const std::function<void()>& callback)
namespace
{
std::lock_guard _(mutex_);
callbacks_.add(callback);
constexpr bool cond_continue = false;
constexpr bool cond_end = true;
struct task
{
std::function<bool()> handler{};
std::chrono::milliseconds interval{};
std::chrono::high_resolution_clock::time_point last_call{};
};
using task_list = std::vector<task>;
class task_pipeline
{
public:
void add(task&& task)
{
new_callbacks_.access([&task](task_list& tasks)
{
tasks.emplace_back(std::move(task));
});
}
void execute()
{
callbacks_.access([&](task_list& tasks)
{
this->merge_callbacks();
for (auto i = tasks.begin(); i != tasks.end();)
{
const auto now = std::chrono::high_resolution_clock::now();
const auto diff = now - i->last_call;
if (diff < i->interval)
{
++i;
continue;
}
i->last_call = now;
const auto res = i->handler();
if (res == cond_end)
{
i = tasks.erase(i);
}
else
{
++i;
}
}
});
}
private:
utils::concurrency::container<task_list> new_callbacks_;
utils::concurrency::container<task_list, std::recursive_mutex> callbacks_;
void merge_callbacks()
{
callbacks_.access([&](task_list& tasks)
{
new_callbacks_.access([&](task_list& new_tasks)
{
tasks.insert(tasks.end(), std::move_iterator<task_list::iterator>(new_tasks.begin()), std::move_iterator<task_list::iterator>(new_tasks.end()));
new_tasks = {};
});
});
}
};
volatile bool kill = false;
std::thread thread;
task_pipeline pipelines[scheduler::pipeline::count];
}
void scheduler::once(const std::function<void()>& callback)
void scheduler::execute(const pipeline type)
{
std::lock_guard _(mutex_);
single_callbacks_.add(callback);
assert(type >= 0 && type < pipeline::count);
pipelines[type].execute();
}
void scheduler::error(const std::string& message, int level)
void scheduler::r_end_frame_stub()
{
std::lock_guard _(mutex_);
errors_.emplace(message, level);
reinterpret_cast<void(*)()>(SELECT_VALUE(0x4193D0, 0x67F840, 0x0))();
execute(pipeline::renderer);
}
void scheduler::frame_stub()
void scheduler::g_glass_update_stub()
{
reinterpret_cast<void(*)()>(SELECT_VALUE(0x4E3730, 0x505BB0, 0x481EA0))();
execute(pipeline::server);
}
void scheduler::main_frame_stub()
{
execute();
reinterpret_cast<void(*)()>(SELECT_VALUE(0x458600, 0x556470, 0x4DB070))();
execute(pipeline::main);
}
__declspec(naked) void scheduler::execute()
void scheduler::schedule(const std::function<bool()>& callback, const pipeline type,
const std::chrono::milliseconds delay)
{
__asm
{
call execute_error
call execute_safe
retn
}
assert(type >= 0 && type < pipeline::count);
task task;
task.handler = callback;
task.interval = delay;
task.last_call = std::chrono::high_resolution_clock::now();
pipelines[type].add(std::move(task));
}
void scheduler::execute_safe()
void scheduler::loop(const std::function<void()>& callback, const pipeline type,
const std::chrono::milliseconds delay)
{
for (auto callback : callbacks_)
schedule([callback]()
{
(*callback)();
}
for (auto callback : single_callbacks_)
{
single_callbacks_.remove(callback);
(*callback)();
}
callback();
return cond_continue;
}, type, delay);
}
void scheduler::execute_error()
void scheduler::once(const std::function<void()>& callback, const pipeline type,
const std::chrono::milliseconds delay)
{
const char* message = nullptr;
int level = 0;
if (get_next_error(&message, &level) && message)
schedule([callback]()
{
game::native::Com_Error(level, "%s", message);
}
callback();
return cond_end;
}, type, delay);
}
bool scheduler::get_next_error(const char** error_message, int* error_level)
void scheduler::post_start()
{
std::lock_guard _(mutex_);
if (errors_.empty())
thread = utils::thread::create_named_thread("Async Scheduler", []()
{
*error_message = nullptr;
return false;
}
const auto error = errors_.front();
errors_.pop();
*error_level = error.second;
*error_message = utils::string::va("%s", error.first.data());
return true;
while (!kill)
{
execute(pipeline::async);
std::this_thread::sleep_for(10ms);
}
});
}
void scheduler::post_load()
{
utils::hook(SELECT_VALUE(0x44C7DB, 0x55688E, 0x4DB324), frame_stub, HOOK_CALL).install()->quick();
utils::hook(SELECT_VALUE(0x44C7DB, 0x55688E, 0x4DB324), main_frame_stub, HOOK_CALL).install()->quick();
if (!game::is_dedi())
{
utils::hook(SELECT_VALUE(0x57F7F8, 0x4978E2, 0x0), r_end_frame_stub, HOOK_CALL).install()->quick();
}
// Hook a function inside G_RunFrame. Fixes TLS issues
utils::hook(SELECT_VALUE(0x52EFBC, 0x50CEC6, 0x48B277), g_glass_update_stub, HOOK_CALL).install()->quick();
}
void scheduler::pre_destroy()
{
std::lock_guard _(mutex_);
callbacks_.clear();
single_callbacks_.clear();
kill = true;
if (thread.joinable())
{
thread.join();
}
}
REGISTER_MODULE(scheduler);

View File

@ -1,27 +1,40 @@
#pragma once
#include "utils/concurrent_list.hpp"
class scheduler final : public module
{
public:
static void on_frame(const std::function<void()>& callback);
static void once(const std::function<void()>& callback);
enum pipeline
{
// Asynchronuous pipeline, disconnected from the game
async = 0,
static void error(const std::string& message, int level);
// The game's rendering pipeline
renderer,
// The game's server thread
server,
// The game's main thread
main,
count,
};
void post_start() override;
void post_load() override;
void pre_destroy() override;
static void schedule(const std::function<bool()>& callback, pipeline type = pipeline::async,
std::chrono::milliseconds delay = 0ms);
static void loop(const std::function<void()>& callback, pipeline type = pipeline::async,
std::chrono::milliseconds delay = 0ms);
static void once(const std::function<void()>& callback, pipeline type = pipeline::async,
std::chrono::milliseconds delay = 0ms);
private:
static std::mutex mutex_;
static std::queue<std::pair<std::string, int>> errors_;
static utils::concurrent_list<std::function<void()>> callbacks_;
static utils::concurrent_list<std::function<void()>> single_callbacks_;
static void execute(const pipeline type);
static void frame_stub();
static void execute();
static void execute_safe();
static void execute_error();
static bool get_next_error(const char** error_message, int* error_level);
static void r_end_frame_stub();
static void g_glass_update_stub();
static void main_frame_stub();
};

View File

@ -48,7 +48,7 @@ private:
{
const auto script_dir = "open-iw5/scripts/"s;
if(!utils::io::directory_exists(script_dir))
if (!utils::io::directory_exists(script_dir))
{
return;
}
@ -115,7 +115,10 @@ private:
printf("%s\n", e.what());
printf("**************************************\n\n");
scheduler::error("Script execution error\n(see console for actual details)\n", 5);
scheduler::once([]
{
game::native::Com_Error(game::native::errorParm_t::ERR_SCRIPT, "Script execution error\n(see console for actual details)\n");
}, scheduler::pipeline::main);
}
static void start_execution_stub()

View File

@ -1,10 +1,12 @@
#include <std_include.hpp>
#include <loader/module_loader.hpp>
#include <utils/hook.hpp>
#include <utils/string.hpp>
#include "game/game.hpp"
#include "test_clients.hpp"
#include "command.hpp"
#include "scheduler.hpp"
bool test_clients::can_add()
{
@ -91,7 +93,7 @@ game::native::gentity_s* test_clients::sv_add_test_client()
return client->gentity;
}
void test_clients::spawn()
void test_clients::gscr_add_test_client()
{
const auto* ent = test_clients::sv_add_test_client();
@ -101,6 +103,35 @@ void test_clients::spawn()
}
}
void test_clients::spawn(const int count)
{
for (int i = 0; i < count; ++i)
{
scheduler::once([]()
{
auto* ent = sv_add_test_client();
if (ent == nullptr) return;
game::native::Scr_AddEntityNum(ent->s.number, 0);
scheduler::once([ent]()
{
game::native::Scr_AddString("autoassign");
game::native::Scr_AddString("team_marinesopfor");
game::native::Scr_Notify(ent, static_cast<std::uint16_t>(game::native::SL_GetString("menuresponse", 0)), 2);
scheduler::once([ent]()
{
game::native::Scr_AddString(utils::string::va("class%i", std::rand() % 5));
game::native::Scr_AddString("changeclass");
game::native::Scr_Notify(ent, static_cast<std::uint16_t>(game::native::SL_GetString("menuresponse", 0)), 2);
}, scheduler::pipeline::server, 2s);
}, scheduler::pipeline::server, 1s);
}, scheduler::pipeline::server, 2s * (i + 1));
}
}
void test_clients::scr_shutdown_system_mp_stub(unsigned char sys)
{
game::native::SV_DropAllBots();
@ -158,11 +189,15 @@ void test_clients::post_load()
if (game::is_mp()) this->patch_mp();
else return; // No sp/dedi bots for now :(
command::add("spawnBot", []()
command::add("spawnBot", [](const command::params& params)
{
// Because I am unable to expand the scheduler at the moment
// we only get one bot at the time
test_clients::spawn();
if (params.size() < 2)
{
return;
}
const auto count = std::atoi(params.get(1));
test_clients::spawn(count);
});
}
@ -180,7 +215,7 @@ void test_clients::patch_mp()
utils::hook(0x576DCC, &test_clients::check_timeouts_stub_mp, HOOK_JUMP).install()->quick(); // SV_CheckTimeouts
// Replace nullsubbed gsc func "GScr_AddTestClient" with our spawn
utils::hook::set<void(*)()>(0x8AC8DC, test_clients::spawn);
utils::hook::set<void(*)()>(0x8AC8DC, test_clients::gscr_add_test_client);
}
REGISTER_MODULE(test_clients);

View File

@ -10,7 +10,8 @@ private:
static bool can_add();
static game::native::gentity_s* sv_add_test_client();
static void spawn();
static void gscr_add_test_client();
static void spawn(int count);
static void scr_shutdown_system_mp_stub(unsigned char sys);

46
src/utils/concurrency.hpp Normal file
View File

@ -0,0 +1,46 @@
#pragma once
#include <mutex>
namespace utils::concurrency
{
template <typename T, typename MutexType = std::mutex>
class container
{
public:
template <typename R = void, typename F>
R access(F&& accessor) const
{
std::lock_guard<MutexType> _{mutex_};
return accessor(object_);
}
template <typename R = void, typename F>
R access(F&& accessor)
{
std::lock_guard<MutexType> _{mutex_};
return accessor(object_);
}
template <typename R = void, typename F>
R access_with_lock(F&& accessor) const
{
std::unique_lock<MutexType> lock{mutex_};
return accessor(object_, lock);
}
template <typename R = void, typename F>
R access_with_lock(F&& accessor)
{
std::unique_lock<MutexType> lock{mutex_};
return accessor(object_, lock);
}
T& get_raw() { return object_; }
const T& get_raw() const { return object_; }
private:
mutable MutexType mutex_{};
T object_{};
};
}

View File

@ -1,131 +0,0 @@
#pragma once
namespace utils
{
template <typename T>
class concurrent_list final
{
public:
class element final
{
public:
explicit element(std::recursive_mutex* mutex, std::shared_ptr<T> entry = {},
std::shared_ptr<element> next = {}) :
mutex_(mutex),
entry_(std::move(entry)),
next_(std::move(next))
{
}
void remove(const std::shared_ptr<T>& element)
{
std::lock_guard _(*this->mutex_);
if (!this->next_) return;
if (this->next_->entry_.get() == element.get())
{
this->next_ = this->next_->next_;
}
else
{
this->next_->remove(element);
}
}
std::shared_ptr<element> get_next() const
{
std::lock_guard _(*this->mutex_);
return this->next_;
}
std::shared_ptr<T> operator*() const
{
std::lock_guard _(*this->mutex_);
return this->entry_;
}
element& operator++()
{
std::lock_guard _(*this->mutex_);
*this = this->next_ ? *this->next_ : element(this->mutex_);
return *this;
}
element operator++(int)
{
std::lock_guard _(*this->mutex_);
auto result = *this;
this->operator++();
return result;
}
bool operator==(const element& other)
{
std::lock_guard _(*this->mutex_);
return this->entry_.get() == other.entry_.get();
}
bool operator!=(const element& other)
{
std::lock_guard _(*this->mutex_);
return !(*this == other);
}
private:
std::recursive_mutex* mutex_;
std::shared_ptr<T> entry_;
std::shared_ptr<element> next_;
};
element begin()
{
std::lock_guard _(this->mutex_);
return this->entry_ ? *this->entry_ : this->end();
}
element end()
{
std::lock_guard _(this->mutex_);
return element(&this->mutex_);
}
void remove(const element& entry)
{
std::lock_guard _(this->mutex_);
this->remove(*entry);
}
void remove(const std::shared_ptr<T>& element)
{
std::lock_guard _(this->mutex_);
if (!this->entry_) return;
if ((**this->entry_).get() == element.get())
{
this->entry_ = this->entry_->get_next();
}
else
{
this->entry_->remove(element);
}
}
void add(const T& object)
{
std::lock_guard _(this->mutex_);
const auto object_ptr = std::make_shared<T>(object);
this->entry_ = std::make_shared<element>(&this->mutex_, object_ptr, this->entry_);
}
void clear()
{
std::lock_guard _(this->mutex_);
this->entry_ = {};
}
private:
std::recursive_mutex mutex_;
std::shared_ptr<element> entry_;
};
}

View File

@ -36,6 +36,19 @@ namespace utils::string
return text;
}
std::wstring convert(const std::string& str)
{
std::wstring result;
result.reserve(str.size());
for (const auto& chr : str)
{
result.push_back(static_cast<wchar_t>(chr));
}
return result;
}
std::string dump_hex(const std::string& data, const std::string& separator)
{
std::string result;

View File

@ -77,5 +77,7 @@ namespace utils::string
std::string to_lower(std::string text);
std::string to_upper(std::string text);
std::wstring convert(const std::string& str);
std::string dump_hex(const std::string& data, const std::string& separator = " ");
}

130
src/utils/thread.cpp Normal file
View File

@ -0,0 +1,130 @@
#include <std_include.hpp>
#include "thread.hpp"
#include "string.hpp"
#include "nt.hpp"
#include <TlHelp32.h>
#include <gsl/gsl>
namespace utils::thread
{
bool set_name(const HANDLE t, const std::string& name)
{
const nt::library kernel32("kernel32.dll");
if (!kernel32)
{
return false;
}
const auto set_description = kernel32.get_proc<HRESULT(WINAPI*)(HANDLE, PCWSTR)>("SetThreadDescription");
if (!set_description)
{
return false;
}
return SUCCEEDED(set_description(t, string::convert(name).data()));
}
bool set_name(const DWORD id, const std::string& name)
{
auto* const t = OpenThread(THREAD_SET_LIMITED_INFORMATION, FALSE, id);
if (!t) return false;
const auto _ = gsl::finally([t]()
{
CloseHandle(t);
});
return set_name(t, name);
}
bool set_name(std::thread& t, const std::string& name)
{
return set_name(t.native_handle(), name);
}
bool set_name(const std::string& name)
{
return set_name(GetCurrentThread(), name);
}
std::vector<DWORD> get_thread_ids()
{
auto* const h = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, GetCurrentProcessId());
if (h == INVALID_HANDLE_VALUE)
{
return {};
}
const auto _ = gsl::finally([h]()
{
CloseHandle(h);
});
THREADENTRY32 entry{};
entry.dwSize = sizeof(entry);
if (!Thread32First(h, &entry))
{
return {};
}
std::vector<DWORD> ids{};
do
{
const auto check_size = entry.dwSize < FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID)
+ sizeof(entry.th32OwnerProcessID);
entry.dwSize = sizeof(entry);
if (check_size && entry.th32OwnerProcessID == GetCurrentProcessId())
{
ids.emplace_back(entry.th32ThreadID);
}
} while (Thread32Next(h, &entry));
return ids;
}
void for_each_thread(const std::function<void(HANDLE)>& callback)
{
const auto ids = get_thread_ids();
for (const auto& id : ids)
{
auto* const thread = OpenThread(THREAD_ALL_ACCESS, FALSE, id);
if (thread != nullptr)
{
const auto _ = gsl::finally([thread]()
{
CloseHandle(thread);
});
callback(thread);
}
}
}
void suspend_other_threads()
{
for_each_thread([](const HANDLE thread)
{
if (GetThreadId(thread) != GetCurrentThreadId())
{
SuspendThread(thread);
}
});
}
void resume_other_threads()
{
for_each_thread([](const HANDLE thread)
{
if (GetThreadId(thread) != GetCurrentThreadId())
{
ResumeThread(thread);
}
});
}
}

24
src/utils/thread.hpp Normal file
View File

@ -0,0 +1,24 @@
#pragma once
#include <thread>
namespace utils::thread
{
bool set_name(HANDLE t, const std::string& name);
bool set_name(DWORD id, const std::string& name);
bool set_name(std::thread& t, const std::string& name);
bool set_name(const std::string& name);
template <typename ...Args>
std::thread create_named_thread(const std::string& name, Args&&... args)
{
auto t = std::thread(std::forward<Args>(args)...);
set_name(t, name);
return t;
}
std::vector<DWORD> get_thread_ids();
void for_each_thread(const std::function<void(HANDLE)>& callback);
void suspend_other_threads();
void resume_other_threads();
}