[Node] Rewrite entire node system

This commit is contained in:
momo5502 2017-06-25 21:54:35 +02:00
parent e841ca48eb
commit b1a91125fc
19 changed files with 547 additions and 1053 deletions

View File

@ -60,6 +60,7 @@ namespace Components
Loader::Register(new ModList()); Loader::Register(new ModList());
Loader::Register(new Monitor()); Loader::Register(new Monitor());
Loader::Register(new Network()); Loader::Register(new Network());
Loader::Register(new Session());
Loader::Register(new Theatre()); Loader::Register(new Theatre());
//Loader::Register(new ClanTags()); //Loader::Register(new ClanTags());
Loader::Register(new Download()); Loader::Register(new Download());

View File

@ -92,6 +92,7 @@ namespace Components
#include "Modules/Logger.hpp" #include "Modules/Logger.hpp"
#include "Modules/Friends.hpp" #include "Modules/Friends.hpp"
#include "Modules/IPCPipe.hpp" #include "Modules/IPCPipe.hpp"
#include "Modules/Session.hpp"
#include "Modules/ClanTags.hpp" #include "Modules/ClanTags.hpp"
#include "Modules/Download.hpp" #include "Modules/Download.hpp"
#include "Modules/Playlist.hpp" #include "Modules/Playlist.hpp"

View File

@ -105,7 +105,7 @@ namespace Components
if (entry->server.getType() == Game::NA_LOOPBACK || (entry->server.getType() == Game::NA_IP && entry->server.getIP().full == 0x0100007F)) entry->server.setType(Game::NA_BAD); if (entry->server.getType() == Game::NA_LOOPBACK || (entry->server.getType() == Game::NA_IP && entry->server.getIP().full == 0x0100007F)) entry->server.setType(Game::NA_BAD);
else if (entry->server.getType() != Game::NA_BAD && entry->server != oldAddress) else if (entry->server.getType() != Game::NA_BAD && entry->server != oldAddress)
{ {
Node::AddNode(entry->server); Node::Add(entry->server);
Network::SendCommand(entry->server, "getinfo", Utils::Cryptography::Rand::GenerateChallenge()); Network::SendCommand(entry->server, "getinfo", Utils::Cryptography::Rand::GenerateChallenge());
} }

View File

@ -32,7 +32,8 @@ namespace Components
while (!interval.elapsed(15s)) while (!interval.elapsed(15s))
{ {
Utils::Hook::Call<void()>(0x49F0B0)(); // Com_ClientPacketEvent Utils::Hook::Call<void()>(0x49F0B0)(); // Com_ClientPacketEvent
Node::FrameHandler(); Session::RunFrame();
Node::RunFrame();
ServerList::Frame(); ServerList::Frame();
std::this_thread::sleep_for(10ms); std::this_thread::sleep_for(10ms);

View File

@ -14,7 +14,7 @@ namespace Components
{ {
Game::SockadrToNetadr(addr, &this->address); Game::SockadrToNetadr(addr, &this->address);
} }
bool Network::Address::operator==(const Network::Address &obj) bool Network::Address::operator==(const Network::Address &obj) const
{ {
return Game::NET_CompareAdr(this->address, obj.address); return Game::NET_CompareAdr(this->address, obj.address);
} }
@ -67,11 +67,11 @@ namespace Components
{ {
return &this->address; return &this->address;
} }
const char* Network::Address::getCString() const char* Network::Address::getCString() const
{ {
return Game::NET_AdrToString(this->address); return Game::NET_AdrToString(this->address);
} }
std::string Network::Address::getString() std::string Network::Address::getString() const
{ {
return this->getCString(); return this->getCString();
} }
@ -123,18 +123,6 @@ namespace Components
{ {
return (this->getType() != Game::netadrtype_t::NA_BAD); return (this->getType() != Game::netadrtype_t::NA_BAD);
} }
void Network::Address::serialize(Proto::Network::Address* protoAddress)
{
protoAddress->set_ip(this->getIP().full);
protoAddress->set_port(this->getPort() & 0xFFFF);
}
void Network::Address::deserialize(const Proto::Network::Address& protoAddress)
{
this->setIP(protoAddress.ip());
this->setPort(static_cast<uint16_t>(protoAddress.port()));
this->setType(Game::netadrtype_t::NA_IP);
}
void Network::Handle(std::string packet, Utils::Slot<Network::Callback> callback) void Network::Handle(std::string packet, Utils::Slot<Network::Callback> callback)
{ {
Network::PacketHandlers[Utils::String::ToLower(packet)] = callback; Network::PacketHandlers[Utils::String::ToLower(packet)] = callback;

View File

@ -19,9 +19,8 @@ namespace Components
Address(Game::netadr_t addr) : address(addr) {} Address(Game::netadr_t addr) : address(addr) {}
Address(Game::netadr_t* addr) : Address(*addr) {} Address(Game::netadr_t* addr) : Address(*addr) {}
Address(const Address& obj) : address(obj.address) {}; Address(const Address& obj) : address(obj.address) {};
Address(const Proto::Network::Address& addr) { this->deserialize(addr); }; bool operator!=(const Address &obj) const { return !(*this == obj); };
bool operator!=(const Address &obj) { return !(*this == obj); }; bool operator==(const Address &obj) const;
bool operator==(const Address &obj);
void setPort(unsigned short port); void setPort(unsigned short port);
unsigned short getPort(); unsigned short getPort();
@ -37,17 +36,14 @@ namespace Components
void toSockAddr(sockaddr* addr); void toSockAddr(sockaddr* addr);
void toSockAddr(sockaddr_in* addr); void toSockAddr(sockaddr_in* addr);
Game::netadr_t* get(); Game::netadr_t* get();
const char* getCString(); const char* getCString() const;
std::string getString(); std::string getString() const;
bool isLocal(); bool isLocal();
bool isSelf(); bool isSelf();
bool isValid(); bool isValid();
bool isLoopback(); bool isLoopback();
void serialize(Proto::Network::Address* protoAddress);
void deserialize(const Proto::Network::Address& protoAddress);
private: private:
Game::netadr_t address; Game::netadr_t address;
}; };
@ -94,3 +90,12 @@ namespace Components
static void PacketErrorCheck(); static void PacketErrorCheck();
}; };
} }
template <>
struct std::hash<Components::Network::Address>
{
std::size_t operator()(const Components::Network::Address& k) const
{
return (std::hash<std::string>()(k.getString()));
}
};

File diff suppressed because it is too large Load Diff

View File

@ -1,99 +1,60 @@
#pragma once #pragma once
#define NODE_QUERY_INTERVAL 1000 * 60 * 2 // Query nodelist from nodes evry 2 minutes #define NODE_HALFLIFE 3min
#define NODE_QUERY_TIMEOUT 1000 * 30 * 1 // Invalidate nodes after 30 seconds without query response #define NODE_REQUEST_LIMIT 3
#define NODE_INVALID_DELETE 1000 * 60 * 10 // Delete invalidated nodes after 10 minutes
#define NODE_FRAME_QUERY_LIMIT 3 // Limit of nodes to be queried per frame
#define NODE_FRAME_LOCK 60 // Limit of max frames per second
#define NODE_PACKET_LIMIT 111 // Send 111 nodes per synchronization packet
#define NODE_STORE_INTERVAL 1000 * 60* 1 // Store nodes every minute
#define SESSION_TIMEOUT 1000 * 10 // 10 seconds session timeout
#define NODE_IP_LIMIT 15
#define NODE_VERSION 5
namespace Components namespace Components
{ {
class Node : public Component class Node : public Component
{ {
public: public:
class Data
{
public:
uint64_t protocol;
};
class Entry
{
public:
Network::Address address;
Data data;
std::optional<Utils::Time::Point> lastRequest;
std::optional<Utils::Time::Point> lastResponse;
Utils::Time::Point creationPoint;
bool isValid();
bool isDead();
bool requiresRequest();
void sendRequest();
void reset();
};
Node(); Node();
~Node(); ~Node();
bool unitTest() override; static void Add(Network::Address address);
static void RunFrame();
static void SyncNodeList(); static void Synchronize();
static void AddNode(Network::Address address, bool def = false);
static unsigned int GetValidNodeCount();
static void LoadNodeRemotePreset(); static void LoadNodeRemotePreset();
static void FrameHandler();
private: private:
enum EntryState static std::recursive_mutex Mutex;
{ static std::vector<Entry> Nodes;
STATE_UNKNOWN,
STATE_NEGOTIATING,
STATE_VALID,
STATE_INVALID,
};
class NodeEntry static void HandleResponse(Network::Address address, std::string data);
{ static void HandleRequest(Network::Address address, std::string data);
public:
Network::Address address;
std::string challenge;
Utils::Cryptography::ECC::Key publicKey;
EntryState state;
bool registered; // Do we consider this node as registered? static void SendList(Network::Address address);
int lastTime; // Last time we heard anything from the server itself
int lastHeard; // Last time we heard something of the server at all (refs form other nodes)
int lastListQuery; // Last time we got the list of the node
bool def;
// This is only relevant for clients
bool isDedi;
uint32_t protocol;
uint32_t version;
};
class ClientSession
{
public:
Network::Address address;
std::string challenge;
bool valid;
//bool terminated; // Sessions can't explicitly be terminated, they can only timeout
int lastTime;
};
static Utils::Cryptography::ECC::Key SignatureKey;
static std::recursive_mutex NodeMutex;
static std::mutex SessionMutex;
static std::vector<NodeEntry> Nodes;
static std::vector<ClientSession> Sessions;
static void LoadNodes();
static void LoadNodePreset(); static void LoadNodePreset();
static void LoadNodes();
static void StoreNodes(bool force); static void StoreNodes(bool force);
static void PerformRegistration(Network::Address address);
static void SendNodeList(Network::Address address);
static NodeEntry* FindNode(Network::Address address);
static ClientSession* FindSession(Network::Address address);
static void DeleteInvalidNodes();
static void DeleteInvalidSessions();
static unsigned short GetPort(); static unsigned short GetPort();
static const char* GetStateName(EntryState state);
}; };
} }

View File

@ -24,7 +24,7 @@ namespace Components
void Party::Connect(Network::Address target) void Party::Connect(Network::Address target)
{ {
Node::AddNode(target); Node::Add(target);
Party::Container.valid = true; Party::Container.valid = true;
Party::Container.awaitingPlaylist = false; Party::Container.awaitingPlaylist = false;

View File

@ -283,7 +283,7 @@ namespace Components
Network::SendCommand(ServerList::RefreshContainer.host, "getservers", Utils::String::VA("IW4 %i full empty", PROTOCOL)); Network::SendCommand(ServerList::RefreshContainer.host, "getservers", Utils::String::VA("IW4 %i full empty", PROTOCOL));
//Network::SendCommand(ServerList::RefreshContainer.Host, "getservers", "0 full empty"); //Network::SendCommand(ServerList::RefreshContainer.Host, "getservers", "0 full empty");
#else #else
Node::SyncNodeList(); Node::Synchronize();
#endif #endif
} }
else if (ServerList::IsFavouriteList()) else if (ServerList::IsFavouriteList())

View File

@ -0,0 +1,211 @@
#include "STDInclude.hpp"
namespace Components
{
std::recursive_mutex Session::Mutex;
std::unordered_map<Network::Address, Session::Frame> Session::Sessions;
std::unordered_map<Network::Address, std::queue<std::shared_ptr<Session::Packet>>> Session::PacketQueue;
Utils::Cryptography::ECC::Key Session::SignatureKey;
std::map<std::string, Utils::Slot<Network::Callback>> Session::PacketHandlers;
void Session::Send(Network::Address target, std::string command, std::string data)
{
std::lock_guard<std::recursive_mutex> _(Session::Mutex);
auto queue = Session::PacketQueue.find(target);
if (queue == Session::PacketQueue.end())
{
Session::PacketQueue[target] = std::queue<std::shared_ptr<Session::Packet>>();
queue = Session::PacketQueue.find(target);
if (queue == Session::PacketQueue.end()) Logger::Error("Failed to enqueue session packet!\n");
}
std::shared_ptr<Session::Packet> packet = std::make_shared<Session::Packet>();
packet->command = command;
packet->data = data;
packet->tries = 0;
queue->second.push(packet);
}
void Session::Handle(std::string packet, Utils::Slot<Network::Callback> callback)
{
std::lock_guard<std::recursive_mutex> _(Session::Mutex);
Session::PacketHandlers[packet] = callback;
}
void Session::RunFrame()
{
std::lock_guard<std::recursive_mutex> _(Session::Mutex);
for (auto queue = Session::PacketQueue.begin(); queue != Session::PacketQueue.end();)
{
if (queue->second.empty())
{
queue = Session::PacketQueue.erase(queue);
continue;
}
std::shared_ptr<Session::Packet> packet = queue->second.front();
if (!packet->lastTry.has_value() || !packet->tries || (packet->lastTry.has_value() && packet->lastTry->elapsed(SESSION_TIMEOUT)))
{
if (packet->tries <= SESSION_MAX_RETRIES)
{
packet->tries++;
packet->lastTry.emplace(Utils::Time::Point());
Network::SendCommand(queue->first, "sessionSyn");
}
else
{
queue->second.pop();
}
}
++queue;
}
}
Session::Session()
{
Session::SignatureKey = Utils::Cryptography::ECC::GenerateKey(512);
Scheduler::OnFrame(Session::RunFrame);
Network::Handle("sessionSyn", [](Network::Address address, std::string data)
{
std::lock_guard<std::recursive_mutex> _(Session::Mutex);
Session::Frame frame;
frame.challenge = Utils::Cryptography::Rand::GenerateChallenge();
Session::Sessions[address] = frame;
Network::SendCommand(address, "sessionAck", frame.challenge);
});
Network::Handle("sessionAck", [](Network::Address address, std::string data)
{
std::lock_guard<std::recursive_mutex> _(Session::Mutex);
auto queue = Session::PacketQueue.find(address);
if (queue == Session::PacketQueue.end()) return;
if (!queue->second.empty())
{
std::shared_ptr<Session::Packet> packet = queue->second.front();
queue->second.pop();
Proto::Session::Packet dataPacket;
dataPacket.set_publickey(Session::SignatureKey.getPublicKey());
dataPacket.set_signature(Utils::Cryptography::ECC::SignMessage(Session::SignatureKey, data));
dataPacket.set_command(packet->command);
dataPacket.set_data(packet->data);
Network::SendCommand(address, "sessionFin", dataPacket.SerializeAsString());
}
});
Network::Handle("sessionFin", [](Network::Address address, std::string data)
{
std::lock_guard<std::recursive_mutex> _(Session::Mutex);
auto frame = Session::Sessions.find(address);
if (frame == Session::Sessions.end()) return;
std::string challenge = frame->second.challenge;
Session::Sessions.erase(frame);
Proto::Session::Packet dataPacket;
if (!dataPacket.ParseFromString(data)) return;
Utils::Cryptography::ECC::Key publicKey;
publicKey.set(dataPacket.publickey());
if (!Utils::Cryptography::ECC::VerifyMessage(publicKey, challenge, dataPacket.signature())) return;
auto handler = Session::PacketHandlers.find(dataPacket.command());
if (handler == Session::PacketHandlers.end()) return;
handler->second(address, dataPacket.data());
});
}
Session::~Session()
{
std::lock_guard<std::recursive_mutex> _(Session::Mutex);
Session::PacketHandlers.clear();
Session::PacketQueue.clear();
Session::SignatureKey.free();
}
bool Session::unitTest()
{
printf("Testing ECDSA key...");
Utils::Cryptography::ECC::Key key = Utils::Cryptography::ECC::GenerateKey(512);
if (!key.isValid())
{
printf("Error\n");
printf("ECDSA key seems invalid!\n");
return false;
}
printf("Success\n");
printf("Testing 10 valid signatures...");
for (int i = 0; i < 10; ++i)
{
std::string message = Utils::Cryptography::Rand::GenerateChallenge();
std::string signature = Utils::Cryptography::ECC::SignMessage(key, message);
if (!Utils::Cryptography::ECC::VerifyMessage(key, message, signature))
{
printf("Error\n");
printf("Signature for '%s' (%d) was invalid!\n", message.data(), i);
return false;
}
}
printf("Success\n");
printf("Testing 10 invalid signatures...");
for (int i = 0; i < 10; ++i)
{
std::string message = Utils::Cryptography::Rand::GenerateChallenge();
std::string signature = Utils::Cryptography::ECC::SignMessage(key, message);
// Invalidate the message...
++message[Utils::Cryptography::Rand::GenerateInt() % message.size()];
if (Utils::Cryptography::ECC::VerifyMessage(key, message, signature))
{
printf("Error\n");
printf("Signature for '%s' (%d) was valid? What the fuck? That is absolutely impossible...\n", message.data(), i);
return false;
}
}
printf("Success\n");
printf("Testing ECDSA key import...");
std::string pubKey = key.getPublicKey();
std::string message = Utils::Cryptography::Rand::GenerateChallenge();
std::string signature = Utils::Cryptography::ECC::SignMessage(key, message);
Utils::Cryptography::ECC::Key testKey;
testKey.set(pubKey);
if (!Utils::Cryptography::ECC::VerifyMessage(key, message, signature))
{
printf("Error\n");
printf("Verifying signature for message '%s' using imported keys failed!\n", message.data());
return false;
}
printf("Success\n");
return true;
}
}

View File

@ -0,0 +1,48 @@
#pragma once
#define SESSION_TIMEOUT 10s
#define SESSION_MAX_RETRIES 3
#define SESSION_REQUEST_LIMIT 3
namespace Components
{
class Session : public Component
{
public:
class Packet
{
public:
std::string command;
std::string data;
unsigned int tries;
std::optional<Utils::Time::Point> lastTry;
};
class Frame
{
public:
std::string challenge;
Utils::Time::Point creationPoint;
};
Session();
~Session();
bool unitTest() override;
static void Send(Network::Address target, std::string command, std::string data = "");
static void Handle(std::string packet, Utils::Slot<Network::Callback> callback);
static void RunFrame();
private:
static std::recursive_mutex Mutex;
static std::unordered_map<Network::Address, Frame> Sessions;
static std::unordered_map<Network::Address, std::queue<std::shared_ptr<Packet>>> PacketQueue;
static Utils::Cryptography::ECC::Key SignatureKey;
static std::map<std::string, Utils::Slot<Network::Callback>> PacketHandlers;
};
}

View File

@ -1,10 +0,0 @@
syntax = "proto3";
package Proto.Network;
// TODO: Add support for IPv6, once the game supports it (I assume we'll implement it :P)
message Address
{
uint32 ip = 1;
uint32 port = 2; // Actually only 16 bits, but apparently protobuf handles that (https://groups.google.com/d/msg/protobuf/Er39mNGnRWU/x6Srz_GrZPgJ)
}

View File

@ -1,23 +1,16 @@
syntax = "proto3"; syntax = "proto3";
package Proto.Node; package Proto.Node;
import "network.proto";
message Packet
{
bytes challenge = 1;
bytes signature = 2;
bytes publickey = 3;
// The port is used to check if a dedi sends data through a redirected port.
// This usually means the port is not forwarded
uint32 port = 4;
}
message List message List
{ {
bool is_dedi = 1; repeated bytes nodes = 1;
repeated Network.Address address = 2;
uint32 protocol = 3; // The port is used to check if a dedi sends data through a redirected port.
uint32 version = 4; // This usually means the port is not forwarded
uint32 port = 2;
// Additional data
bool isNode = 3;
uint64 protocol = 4;
} }

11
src/Proto/session.proto Normal file
View File

@ -0,0 +1,11 @@
syntax = "proto3";
package Proto.Session;
message Packet
{
bytes signature = 1;
bytes publicKey = 2;
bytes command = 3;
bytes data = 4;
}

View File

@ -89,7 +89,7 @@ template <size_t S> class Sizer { };
#endif #endif
// Protobuf // Protobuf
#include "proto/network.pb.h" #include "proto/session.pb.h"
#include "proto/party.pb.h" #include "proto/party.pb.h"
#include "proto/auth.pb.h" #include "proto/auth.pb.h"
#include "proto/node.pb.h" #include "proto/node.pb.h"

View File

@ -8,6 +8,7 @@ namespace Utils
{ {
Utils::Memory::Allocator allocator; Utils::Memory::Allocator allocator;
unsigned long length = (data.size() * 2); unsigned long length = (data.size() * 2);
if (!length) length = 2;
// Make sure the buffer is large enough // Make sure the buffer is large enough
if (length < 100) length *= 10; if (length < 100) length *= 10;
@ -16,7 +17,6 @@ namespace Utils
if (compress2(reinterpret_cast<Bytef*>(buffer), &length, reinterpret_cast<Bytef*>(const_cast<char*>(data.data())), data.size(), Z_BEST_COMPRESSION) != Z_OK) if (compress2(reinterpret_cast<Bytef*>(buffer), &length, reinterpret_cast<Bytef*>(const_cast<char*>(data.data())), data.size(), Z_BEST_COMPRESSION) != Z_OK)
{ {
Utils::Memory::Free(buffer);
return ""; return "";
} }

View File

@ -13,5 +13,15 @@ namespace Utils
{ {
return ((std::chrono::high_resolution_clock::now() - this->lastPoint) >= nsecs); return ((std::chrono::high_resolution_clock::now() - this->lastPoint) >= nsecs);
} }
std::chrono::high_resolution_clock::duration Point::diff(Point point)
{
return point.lastPoint - this->lastPoint;
}
bool Point::after(Point point)
{
return this->diff(point).count() < 0;
}
} }
} }

View File

@ -6,7 +6,7 @@ namespace Utils
{ {
class Interval class Interval
{ {
private: protected:
std::chrono::high_resolution_clock::time_point lastPoint; std::chrono::high_resolution_clock::time_point lastPoint;
public: public:
@ -15,5 +15,14 @@ namespace Utils
void update(); void update();
bool elapsed(std::chrono::nanoseconds nsecs); bool elapsed(std::chrono::nanoseconds nsecs);
}; };
class Point : public Interval
{
public:
Point() : Interval() {}
std::chrono::high_resolution_clock::duration diff(Point point);
bool after(Point point);
};
} }
} }