From b718803ecb44c5cab653fed0c36b3f2430bf5980 Mon Sep 17 00:00:00 2001 From: Diavolo Date: Tue, 28 Jun 2022 09:26:43 +0200 Subject: [PATCH] [Bans] Refactor --- src/Components/Modules/Bans.cpp | 231 ++++++++++++++++------------ src/Components/Modules/Bans.hpp | 16 +- src/Components/Modules/Download.cpp | 4 +- src/Components/Modules/Download.hpp | 3 +- src/STDInclude.hpp | 1 + src/Utils/NamedMutex.cpp | 43 ++++++ src/Utils/NamedMutex.hpp | 24 +++ 7 files changed, 209 insertions(+), 113 deletions(-) create mode 100644 src/Utils/NamedMutex.cpp create mode 100644 src/Utils/NamedMutex.hpp diff --git a/src/Components/Modules/Bans.cpp b/src/Components/Modules/Bans.cpp index cf13ed16..45382bdc 100644 --- a/src/Components/Modules/Bans.cpp +++ b/src/Components/Modules/Bans.cpp @@ -2,14 +2,18 @@ namespace Components { - std::recursive_mutex Bans::AccessMutex; - - bool Bans::IsBanned(Bans::Entry entry) + // Have only one instance of IW4x read/write the file + std::unique_lock Bans::Lock() { - std::lock_guard _(Bans::AccessMutex); + static Utils::NamedMutex mutex{"iw4x-ban-list-lock"}; + std::unique_lock lock{mutex}; + return lock; + } - Bans::BanList list; - Bans::LoadBans(&list); + bool Bans::IsBanned(const banEntry& entry) + { + BanList list; + LoadBans(&list); if (entry.first.bits) { @@ -24,7 +28,7 @@ namespace Components if (entry.second.full) { - for (auto& ipEntry : list.ipList) + for (const auto& ipEntry : list.ipList) { if (ipEntry.full == entry.second.full) { @@ -36,17 +40,15 @@ namespace Components return false; } - void Bans::InsertBan(Bans::Entry entry) + void Bans::InsertBan(const banEntry& entry) { - std::lock_guard _(Bans::AccessMutex); - - Bans::BanList list; - Bans::LoadBans(&list); + BanList list; + LoadBans(&list); if (entry.first.bits) { bool found = false; - for (auto& idEntry : list.idList) + for (const auto& idEntry : list.idList) { if (idEntry.bits == entry.first.bits) { @@ -64,7 +66,7 @@ namespace Components if (entry.second.full) { bool found = false; - for (auto& ipEntry : list.ipList) + for (const auto& ipEntry : list.ipList) { if (ipEntry.full == entry.second.full) { @@ -79,12 +81,14 @@ namespace Components } } - Bans::SaveBans(&list); + SaveBans(&list); } - void Bans::SaveBans(BanList* list) + void Bans::SaveBans(const BanList* list) { - std::lock_guard _(Bans::AccessMutex); + assert(list != nullptr); + + const auto _ = Lock(); std::vector idVector; std::vector ipVector; @@ -109,95 +113,85 @@ namespace Components { "id", idVector }, }; - FileSystem::FileWriter ban("bans.json"); - ban.write(bans.dump()); + FileSystem::FileWriter ("bans.json").write(bans.dump()); } - void Bans::LoadBans(Bans::BanList* list) + void Bans::LoadBans(BanList* list) { - std::lock_guard _(Bans::AccessMutex); + assert(list != nullptr); + + const auto _ = Lock(); FileSystem::File bans("bans.json"); - if (bans.exists()) + if (!bans.exists()) { - std::string error; - json11::Json banData = json11::Json::parse(bans.getBuffer(), error); + Logger::Debug("bans.json does not exist"); + return; + } - if (!error.empty()) + std::string error; + const auto banData = json11::Json::parse(bans.getBuffer(), error); + + if (!error.empty()) + { + Logger::PrintError(Game::CON_CHANNEL_ERROR, "Failed to parse bans.json: {}\n", error); + return; + } + + if (!banData.is_object()) + { + Logger::Debug("bans.json contains invalid data"); + return; + } + + const auto& idList = banData["id"]; + const auto& ipList = banData["ip"]; + + if (idList.is_array()) + { + for (auto &idEntry : idList.array_items()) { - Logger::Error(Game::ERR_FATAL, "Failed to parse bans (bans.json): {}", error); - } - - if (!list) return; - - if (banData.is_object()) - { - auto idList = banData["id"]; - auto ipList = banData["ip"]; - - if (idList.is_array()) + if (idEntry.is_string()) { - for (auto &idEntry : idList.array_items()) - { - if (idEntry.is_string()) - { - SteamID id; - id.bits = strtoull(idEntry.string_value().data(), nullptr, 16); + SteamID id; + id.bits = strtoull(idEntry.string_value().data(), nullptr, 16); - list->idList.push_back(id); - } - } + list->idList.push_back(id); } + } + } - if (ipList.is_array()) + if (ipList.is_array()) + { + for (auto &ipEntry : ipList.array_items()) + { + if (ipEntry.is_string()) { - for (auto &ipEntry : ipList.array_items()) - { - if (ipEntry.is_string()) - { - Network::Address addr(ipEntry.string_value()); + Network::Address addr(ipEntry.string_value()); - list->ipList.push_back(addr.getIP()); - } - } + list->ipList.push_back(addr.getIP()); } } } } - void Bans::BanClientNum(int num, const std::string& reason) + void Bans::BanClient(Game::client_t* cl, const std::string& reason) { - if (!Dvar::Var("sv_running").get()) - { - Logger::Print("Server is not running.\n"); - return; - } - - if (*Game::svs_clientCount <= num) - { - Logger::Print("Player {} is not on the server\n", num); - return; - } - - Game::client_t* client = &Game::svs_clients[num]; - SteamID guid; - guid.bits = client->steamID; + guid.bits = cl->steamID; - Bans::InsertBan({guid, client->netchan.remoteAddress.ip}); + InsertBan({guid, cl->netchan.remoteAddress.ip}); - Game::SV_GameDropClient(num, reason.data()); + Game::SV_DropClient(cl, reason.data(), true); } void Bans::UnbanClient(SteamID id) { - std::lock_guard _(Bans::AccessMutex); + BanList list; + LoadBans(&list); - Bans::BanList list; - Bans::LoadBans(&list); - - auto entry = std::find_if(list.idList.begin(), list.idList.end(), [&id](SteamID& entry) + const auto entry = std::find_if(list.idList.begin(), list.idList.end(), [&id](const SteamID& entry) { return id.bits == entry.bits; }); @@ -207,17 +201,15 @@ namespace Components list.idList.erase(entry); } - Bans::SaveBans(&list); + SaveBans(&list); } void Bans::UnbanClient(Game::netIP_t ip) { - std::lock_guard _(Bans::AccessMutex); + BanList list; + LoadBans(&list); - Bans::BanList list; - Bans::LoadBans(&list); - - auto entry = std::find_if(list.ipList.begin(), list.ipList.end(), [&ip](Game::netIP_t& entry) + const auto entry = std::find_if(list.ipList.begin(), list.ipList.end(), [&ip](const Game::netIP_t& entry) { return ip.full == entry.full; }); @@ -227,31 +219,75 @@ namespace Components list.ipList.erase(entry); } - Bans::SaveBans(&list); + SaveBans(&list); } Bans::Bans() { - Command::Add("banclient", [](Command::Params* params) + Command::Add("banClient", [](Command::Params* params) { - if (params->size() < 2) return; + if (!Dvar::Var("sv_running").get()) + { + Logger::Print("Server is not running.\n"); + return; + } - std::string reason = "EXE_ERR_BANNED_PERM"; - if (params->size() >= 3) reason = params->join(2); + if (params->size() < 2) + { + Logger::Print("{} : permanently ban a client\n", params->get(0)); + return; + } - Bans::BanClientNum(atoi(params->get(1)), reason); + const auto* input = params->get(1); + + for (auto i = 0; input[i] != '\0'; ++i) + { + if (input[i] < '0' || input[i] > '9') + { + Logger::Print("Bad slot number: {}\n", input); + return; + } + } + + const auto num = std::atoi(input); + + if (num < 0 || num >= *Game::svs_clientCount) + { + Logger::Print("Bad client slot: {}\n", num); + return; + } + + const auto* cl = &Game::svs_clients[num]; + if (cl->state == Game::CS_FREE) + { + Logger::Print("Client {} is not active\n", num); + return; + } + + const std::string reason = params->size() < 3 ? "EXE_ERR_BANNED_PERM" : params->join(2); + Bans::BanClient(&Game::svs_clients[num], reason); }); - Command::Add("unbanclient", [](Command::Params* params) + Command::Add("unbanClient", [](Command::Params* params) { - if (params->size() < 2) return; + if (!Dvar::Var("sv_running").get()) + { + Logger::Print("Server is not running.\n"); + return; + } - std::string type = params->get(1); + if (params->size() < 3) + { + Logger::Print("{} \n", params->get(0)); + return; + } + + const auto* type = params->get(1); if (type == "ip"s) { Network::Address address(params->get(2)); - Bans::UnbanClient(address.getIP()); + UnbanClient(address.getIP()); Logger::Print("Unbanned IP {}\n", params->get(2)); @@ -261,17 +297,10 @@ namespace Components SteamID id; id.bits = strtoull(params->get(2), nullptr, 16); - Bans::UnbanClient(id); + UnbanClient(id); Logger::Print("Unbanned GUID {}\n", params->get(2)); } }); - - // Verify the list on startup - Scheduler::OnGameInitialized([] - { - Bans::BanList list; - Bans::LoadBans(&list); - }, Scheduler::Pipeline::SERVER); } } diff --git a/src/Components/Modules/Bans.hpp b/src/Components/Modules/Bans.hpp index fe5c6410..0ba36d12 100644 --- a/src/Components/Modules/Bans.hpp +++ b/src/Components/Modules/Bans.hpp @@ -5,27 +5,27 @@ namespace Components class Bans : public Component { public: - typedef std::pair Entry; + using banEntry = std::pair; Bans(); - static void BanClientNum(int num, const std::string& reason); + static std::unique_lock Lock(); + + static void BanClient(Game::client_t* cl, const std::string& reason); static void UnbanClient(SteamID id); static void UnbanClient(Game::netIP_t ip); - static bool IsBanned(Entry entry); - static void InsertBan(Entry entry); + static bool IsBanned(const banEntry& entry); + static void InsertBan(const banEntry& entry); private: - class BanList + struct BanList { - public: std::vector idList; std::vector ipList; }; - static std::recursive_mutex AccessMutex; static void LoadBans(BanList* list); - static void SaveBans(BanList* list); + static void SaveBans(const BanList* list); }; } diff --git a/src/Components/Modules/Download.cpp b/src/Components/Modules/Download.cpp index d16e76ee..6cda7c4c 100644 --- a/src/Components/Modules/Download.cpp +++ b/src/Components/Modules/Download.cpp @@ -8,7 +8,7 @@ namespace Components std::thread Download::ServerThread; bool Download::Terminate; - bool Download::ServerRunning; + bool Download::ServerRunning; #pragma region Client @@ -889,7 +889,7 @@ namespace Components } }); - Download::ServerRunning = true; + Download::ServerRunning = true; Download::Terminate = false; Download::ServerThread = std::thread([] { diff --git a/src/Components/Modules/Download.hpp b/src/Components/Modules/Download.hpp index 90472069..cf934f34 100644 --- a/src/Components/Modules/Download.hpp +++ b/src/Components/Modules/Download.hpp @@ -1,5 +1,4 @@ #pragma once -#include namespace Components { @@ -210,7 +209,7 @@ namespace Components static std::vector> ScriptDownloads; static std::thread ServerThread; static bool Terminate; - static bool ServerRunning; + static bool ServerRunning; static void DownloadProgress(FileDownload* fDownload, size_t bytes); diff --git a/src/STDInclude.hpp b/src/STDInclude.hpp index 950c5248..d2deaa11 100644 --- a/src/STDInclude.hpp +++ b/src/STDInclude.hpp @@ -129,6 +129,7 @@ using namespace std::literals; #include "Utils/Json.hpp" #include "Utils/Library.hpp" #include "Utils/Maths.hpp" +#include "Utils/NamedMutex.hpp" #include "Utils/String.hpp" #include "Utils/Thread.hpp" #include "Utils/Time.hpp" diff --git a/src/Utils/NamedMutex.cpp b/src/Utils/NamedMutex.cpp new file mode 100644 index 00000000..48bfed99 --- /dev/null +++ b/src/Utils/NamedMutex.cpp @@ -0,0 +1,43 @@ +#include + +namespace Utils +{ + NamedMutex::NamedMutex(const std::string& name) + { + this->handle_ = CreateMutexA(nullptr, FALSE, name.data()); + } + + NamedMutex::~NamedMutex() + { + if (this->handle_) + { + CloseHandle(this->handle_); + } + } + + void NamedMutex::lock() const + { + if (this->handle_) + { + WaitForSingleObject(this->handle_, INFINITE); + } + } + + bool NamedMutex::try_lock(const std::chrono::milliseconds timeout) const + { + if (this->handle_) + { + return WAIT_OBJECT_0 == WaitForSingleObject(this->handle_, static_cast(timeout.count())); + } + + return false; + } + + void NamedMutex::unlock() const noexcept + { + if (this->handle_) + { + ReleaseMutex(this->handle_); + } + } +} diff --git a/src/Utils/NamedMutex.hpp b/src/Utils/NamedMutex.hpp new file mode 100644 index 00000000..15bf892d --- /dev/null +++ b/src/Utils/NamedMutex.hpp @@ -0,0 +1,24 @@ +#pragma once + +namespace Utils +{ + class NamedMutex + { + public: + explicit NamedMutex(const std::string& name); + ~NamedMutex(); + + NamedMutex(NamedMutex&&) = delete; + NamedMutex(const NamedMutex&) = delete; + NamedMutex& operator=(NamedMutex&&) = delete; + NamedMutex& operator=(const NamedMutex&) = delete; + + void lock() const; + // Lockable requirements + [[nodiscard]] bool try_lock(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) const; + void unlock() const noexcept; + + private: + void* handle_{}; + }; +}