[Bans] Refactor

This commit is contained in:
Diavolo 2022-06-28 09:26:43 +02:00
parent 2daaee8358
commit b718803ecb
No known key found for this signature in database
GPG Key ID: FA77F074E98D98A5
7 changed files with 209 additions and 113 deletions

View File

@ -2,14 +2,18 @@
namespace Components
{
std::recursive_mutex Bans::AccessMutex;
bool Bans::IsBanned(Bans::Entry entry)
// Have only one instance of IW4x read/write the file
std::unique_lock<Utils::NamedMutex> Bans::Lock()
{
std::lock_guard<std::recursive_mutex> _(Bans::AccessMutex);
static Utils::NamedMutex mutex{"iw4x-ban-list-lock"};
std::unique_lock lock{mutex};
return lock;
}
Bans::BanList list;
Bans::LoadBans(&list);
bool Bans::IsBanned(const banEntry& entry)
{
BanList list;
LoadBans(&list);
if (entry.first.bits)
{
@ -24,7 +28,7 @@ namespace Components
if (entry.second.full)
{
for (auto& ipEntry : list.ipList)
for (const auto& ipEntry : list.ipList)
{
if (ipEntry.full == entry.second.full)
{
@ -36,17 +40,15 @@ namespace Components
return false;
}
void Bans::InsertBan(Bans::Entry entry)
void Bans::InsertBan(const banEntry& entry)
{
std::lock_guard<std::recursive_mutex> _(Bans::AccessMutex);
Bans::BanList list;
Bans::LoadBans(&list);
BanList list;
LoadBans(&list);
if (entry.first.bits)
{
bool found = false;
for (auto& idEntry : list.idList)
for (const auto& idEntry : list.idList)
{
if (idEntry.bits == entry.first.bits)
{
@ -64,7 +66,7 @@ namespace Components
if (entry.second.full)
{
bool found = false;
for (auto& ipEntry : list.ipList)
for (const auto& ipEntry : list.ipList)
{
if (ipEntry.full == entry.second.full)
{
@ -79,12 +81,14 @@ namespace Components
}
}
Bans::SaveBans(&list);
SaveBans(&list);
}
void Bans::SaveBans(BanList* list)
void Bans::SaveBans(const BanList* list)
{
std::lock_guard<std::recursive_mutex> _(Bans::AccessMutex);
assert(list != nullptr);
const auto _ = Lock();
std::vector<std::string> idVector;
std::vector<std::string> ipVector;
@ -109,95 +113,85 @@ namespace Components
{ "id", idVector },
};
FileSystem::FileWriter ban("bans.json");
ban.write(bans.dump());
FileSystem::FileWriter ("bans.json").write(bans.dump());
}
void Bans::LoadBans(Bans::BanList* list)
void Bans::LoadBans(BanList* list)
{
std::lock_guard<std::recursive_mutex> _(Bans::AccessMutex);
assert(list != nullptr);
const auto _ = Lock();
FileSystem::File bans("bans.json");
if (bans.exists())
if (!bans.exists())
{
std::string error;
json11::Json banData = json11::Json::parse(bans.getBuffer(), error);
Logger::Debug("bans.json does not exist");
return;
}
if (!error.empty())
std::string error;
const auto banData = json11::Json::parse(bans.getBuffer(), error);
if (!error.empty())
{
Logger::PrintError(Game::CON_CHANNEL_ERROR, "Failed to parse bans.json: {}\n", error);
return;
}
if (!banData.is_object())
{
Logger::Debug("bans.json contains invalid data");
return;
}
const auto& idList = banData["id"];
const auto& ipList = banData["ip"];
if (idList.is_array())
{
for (auto &idEntry : idList.array_items())
{
Logger::Error(Game::ERR_FATAL, "Failed to parse bans (bans.json): {}", error);
}
if (!list) return;
if (banData.is_object())
{
auto idList = banData["id"];
auto ipList = banData["ip"];
if (idList.is_array())
if (idEntry.is_string())
{
for (auto &idEntry : idList.array_items())
{
if (idEntry.is_string())
{
SteamID id;
id.bits = strtoull(idEntry.string_value().data(), nullptr, 16);
SteamID id;
id.bits = strtoull(idEntry.string_value().data(), nullptr, 16);
list->idList.push_back(id);
}
}
list->idList.push_back(id);
}
}
}
if (ipList.is_array())
if (ipList.is_array())
{
for (auto &ipEntry : ipList.array_items())
{
if (ipEntry.is_string())
{
for (auto &ipEntry : ipList.array_items())
{
if (ipEntry.is_string())
{
Network::Address addr(ipEntry.string_value());
Network::Address addr(ipEntry.string_value());
list->ipList.push_back(addr.getIP());
}
}
list->ipList.push_back(addr.getIP());
}
}
}
}
void Bans::BanClientNum(int num, const std::string& reason)
void Bans::BanClient(Game::client_t* cl, const std::string& reason)
{
if (!Dvar::Var("sv_running").get<bool>())
{
Logger::Print("Server is not running.\n");
return;
}
if (*Game::svs_clientCount <= num)
{
Logger::Print("Player {} is not on the server\n", num);
return;
}
Game::client_t* client = &Game::svs_clients[num];
SteamID guid;
guid.bits = client->steamID;
guid.bits = cl->steamID;
Bans::InsertBan({guid, client->netchan.remoteAddress.ip});
InsertBan({guid, cl->netchan.remoteAddress.ip});
Game::SV_GameDropClient(num, reason.data());
Game::SV_DropClient(cl, reason.data(), true);
}
void Bans::UnbanClient(SteamID id)
{
std::lock_guard<std::recursive_mutex> _(Bans::AccessMutex);
BanList list;
LoadBans(&list);
Bans::BanList list;
Bans::LoadBans(&list);
auto entry = std::find_if(list.idList.begin(), list.idList.end(), [&id](SteamID& entry)
const auto entry = std::find_if(list.idList.begin(), list.idList.end(), [&id](const SteamID& entry)
{
return id.bits == entry.bits;
});
@ -207,17 +201,15 @@ namespace Components
list.idList.erase(entry);
}
Bans::SaveBans(&list);
SaveBans(&list);
}
void Bans::UnbanClient(Game::netIP_t ip)
{
std::lock_guard<std::recursive_mutex> _(Bans::AccessMutex);
BanList list;
LoadBans(&list);
Bans::BanList list;
Bans::LoadBans(&list);
auto entry = std::find_if(list.ipList.begin(), list.ipList.end(), [&ip](Game::netIP_t& entry)
const auto entry = std::find_if(list.ipList.begin(), list.ipList.end(), [&ip](const Game::netIP_t& entry)
{
return ip.full == entry.full;
});
@ -227,31 +219,75 @@ namespace Components
list.ipList.erase(entry);
}
Bans::SaveBans(&list);
SaveBans(&list);
}
Bans::Bans()
{
Command::Add("banclient", [](Command::Params* params)
Command::Add("banClient", [](Command::Params* params)
{
if (params->size() < 2) return;
if (!Dvar::Var("sv_running").get<bool>())
{
Logger::Print("Server is not running.\n");
return;
}
std::string reason = "EXE_ERR_BANNED_PERM";
if (params->size() >= 3) reason = params->join(2);
if (params->size() < 2)
{
Logger::Print("{} <client number> : permanently ban a client\n", params->get(0));
return;
}
Bans::BanClientNum(atoi(params->get(1)), reason);
const auto* input = params->get(1);
for (auto i = 0; input[i] != '\0'; ++i)
{
if (input[i] < '0' || input[i] > '9')
{
Logger::Print("Bad slot number: {}\n", input);
return;
}
}
const auto num = std::atoi(input);
if (num < 0 || num >= *Game::svs_clientCount)
{
Logger::Print("Bad client slot: {}\n", num);
return;
}
const auto* cl = &Game::svs_clients[num];
if (cl->state == Game::CS_FREE)
{
Logger::Print("Client {} is not active\n", num);
return;
}
const std::string reason = params->size() < 3 ? "EXE_ERR_BANNED_PERM" : params->join(2);
Bans::BanClient(&Game::svs_clients[num], reason);
});
Command::Add("unbanclient", [](Command::Params* params)
Command::Add("unbanClient", [](Command::Params* params)
{
if (params->size() < 2) return;
if (!Dvar::Var("sv_running").get<bool>())
{
Logger::Print("Server is not running.\n");
return;
}
std::string type = params->get(1);
if (params->size() < 3)
{
Logger::Print("{} <type> <ip or guid>\n", params->get(0));
return;
}
const auto* type = params->get(1);
if (type == "ip"s)
{
Network::Address address(params->get(2));
Bans::UnbanClient(address.getIP());
UnbanClient(address.getIP());
Logger::Print("Unbanned IP {}\n", params->get(2));
@ -261,17 +297,10 @@ namespace Components
SteamID id;
id.bits = strtoull(params->get(2), nullptr, 16);
Bans::UnbanClient(id);
UnbanClient(id);
Logger::Print("Unbanned GUID {}\n", params->get(2));
}
});
// Verify the list on startup
Scheduler::OnGameInitialized([]
{
Bans::BanList list;
Bans::LoadBans(&list);
}, Scheduler::Pipeline::SERVER);
}
}

View File

@ -5,27 +5,27 @@ namespace Components
class Bans : public Component
{
public:
typedef std::pair<SteamID, Game::netIP_t> Entry;
using banEntry = std::pair<SteamID, Game::netIP_t>;
Bans();
static void BanClientNum(int num, const std::string& reason);
static std::unique_lock<Utils::NamedMutex> Lock();
static void BanClient(Game::client_t* cl, const std::string& reason);
static void UnbanClient(SteamID id);
static void UnbanClient(Game::netIP_t ip);
static bool IsBanned(Entry entry);
static void InsertBan(Entry entry);
static bool IsBanned(const banEntry& entry);
static void InsertBan(const banEntry& entry);
private:
class BanList
struct BanList
{
public:
std::vector<SteamID> idList;
std::vector<Game::netIP_t> ipList;
};
static std::recursive_mutex AccessMutex;
static void LoadBans(BanList* list);
static void SaveBans(BanList* list);
static void SaveBans(const BanList* list);
};
}

View File

@ -8,7 +8,7 @@ namespace Components
std::thread Download::ServerThread;
bool Download::Terminate;
bool Download::ServerRunning;
bool Download::ServerRunning;
#pragma region Client
@ -889,7 +889,7 @@ namespace Components
}
});
Download::ServerRunning = true;
Download::ServerRunning = true;
Download::Terminate = false;
Download::ServerThread = std::thread([]
{

View File

@ -1,5 +1,4 @@
#pragma once
#include <Game/Functions.hpp>
namespace Components
{
@ -210,7 +209,7 @@ namespace Components
static std::vector<std::shared_ptr<ScriptDownload>> ScriptDownloads;
static std::thread ServerThread;
static bool Terminate;
static bool ServerRunning;
static bool ServerRunning;
static void DownloadProgress(FileDownload* fDownload, size_t bytes);

View File

@ -129,6 +129,7 @@ using namespace std::literals;
#include "Utils/Json.hpp"
#include "Utils/Library.hpp"
#include "Utils/Maths.hpp"
#include "Utils/NamedMutex.hpp"
#include "Utils/String.hpp"
#include "Utils/Thread.hpp"
#include "Utils/Time.hpp"

43
src/Utils/NamedMutex.cpp Normal file
View File

@ -0,0 +1,43 @@
#include <STDInclude.hpp>
namespace Utils
{
NamedMutex::NamedMutex(const std::string& name)
{
this->handle_ = CreateMutexA(nullptr, FALSE, name.data());
}
NamedMutex::~NamedMutex()
{
if (this->handle_)
{
CloseHandle(this->handle_);
}
}
void NamedMutex::lock() const
{
if (this->handle_)
{
WaitForSingleObject(this->handle_, INFINITE);
}
}
bool NamedMutex::try_lock(const std::chrono::milliseconds timeout) const
{
if (this->handle_)
{
return WAIT_OBJECT_0 == WaitForSingleObject(this->handle_, static_cast<DWORD>(timeout.count()));
}
return false;
}
void NamedMutex::unlock() const noexcept
{
if (this->handle_)
{
ReleaseMutex(this->handle_);
}
}
}

24
src/Utils/NamedMutex.hpp Normal file
View File

@ -0,0 +1,24 @@
#pragma once
namespace Utils
{
class NamedMutex
{
public:
explicit NamedMutex(const std::string& name);
~NamedMutex();
NamedMutex(NamedMutex&&) = delete;
NamedMutex(const NamedMutex&) = delete;
NamedMutex& operator=(NamedMutex&&) = delete;
NamedMutex& operator=(const NamedMutex&) = delete;
void lock() const;
// Lockable requirements
[[nodiscard]] bool try_lock(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) const;
void unlock() const noexcept;
private:
void* handle_{};
};
}