diff --git a/src/Components/Loader.cpp b/src/Components/Loader.cpp index c6deb6c7..cc3399a4 100644 --- a/src/Components/Loader.cpp +++ b/src/Components/Loader.cpp @@ -60,6 +60,7 @@ namespace Components Loader::Register(new ModList()); Loader::Register(new Monitor()); Loader::Register(new Network()); + Loader::Register(new Session()); Loader::Register(new Theatre()); //Loader::Register(new ClanTags()); Loader::Register(new Download()); diff --git a/src/Components/Loader.hpp b/src/Components/Loader.hpp index 76b1d8ff..cfcf6a69 100644 --- a/src/Components/Loader.hpp +++ b/src/Components/Loader.hpp @@ -92,6 +92,7 @@ namespace Components #include "Modules/Logger.hpp" #include "Modules/Friends.hpp" #include "Modules/IPCPipe.hpp" +#include "Modules/Session.hpp" #include "Modules/ClanTags.hpp" #include "Modules/Download.hpp" #include "Modules/Playlist.hpp" diff --git a/src/Components/Modules/Friends.cpp b/src/Components/Modules/Friends.cpp index 035291f6..23dc60cf 100644 --- a/src/Components/Modules/Friends.cpp +++ b/src/Components/Modules/Friends.cpp @@ -105,7 +105,7 @@ namespace Components if (entry->server.getType() == Game::NA_LOOPBACK || (entry->server.getType() == Game::NA_IP && entry->server.getIP().full == 0x0100007F)) entry->server.setType(Game::NA_BAD); else if (entry->server.getType() != Game::NA_BAD && entry->server != oldAddress) { - Node::AddNode(entry->server); + Node::Add(entry->server); Network::SendCommand(entry->server, "getinfo", Utils::Cryptography::Rand::GenerateChallenge()); } diff --git a/src/Components/Modules/Monitor.cpp b/src/Components/Modules/Monitor.cpp index ef8e84ef..17e5603c 100644 --- a/src/Components/Modules/Monitor.cpp +++ b/src/Components/Modules/Monitor.cpp @@ -32,7 +32,8 @@ namespace Components while (!interval.elapsed(15s)) { Utils::Hook::Call(0x49F0B0)(); // Com_ClientPacketEvent - Node::FrameHandler(); + Session::RunFrame(); + Node::RunFrame(); ServerList::Frame(); std::this_thread::sleep_for(10ms); diff --git a/src/Components/Modules/Network.cpp b/src/Components/Modules/Network.cpp index 85c4d945..45cd5aec 100644 --- a/src/Components/Modules/Network.cpp +++ b/src/Components/Modules/Network.cpp @@ -14,7 +14,7 @@ namespace Components { Game::SockadrToNetadr(addr, &this->address); } - bool Network::Address::operator==(const Network::Address &obj) + bool Network::Address::operator==(const Network::Address &obj) const { return Game::NET_CompareAdr(this->address, obj.address); } @@ -67,11 +67,11 @@ namespace Components { return &this->address; } - const char* Network::Address::getCString() + const char* Network::Address::getCString() const { return Game::NET_AdrToString(this->address); } - std::string Network::Address::getString() + std::string Network::Address::getString() const { return this->getCString(); } @@ -123,18 +123,6 @@ namespace Components { return (this->getType() != Game::netadrtype_t::NA_BAD); } - void Network::Address::serialize(Proto::Network::Address* protoAddress) - { - protoAddress->set_ip(this->getIP().full); - protoAddress->set_port(this->getPort() & 0xFFFF); - } - void Network::Address::deserialize(const Proto::Network::Address& protoAddress) - { - this->setIP(protoAddress.ip()); - this->setPort(static_cast(protoAddress.port())); - this->setType(Game::netadrtype_t::NA_IP); - } - void Network::Handle(std::string packet, Utils::Slot callback) { Network::PacketHandlers[Utils::String::ToLower(packet)] = callback; diff --git a/src/Components/Modules/Network.hpp b/src/Components/Modules/Network.hpp index 54814e3e..4a2919b2 100644 --- a/src/Components/Modules/Network.hpp +++ b/src/Components/Modules/Network.hpp @@ -19,9 +19,8 @@ namespace Components Address(Game::netadr_t addr) : address(addr) {} Address(Game::netadr_t* addr) : Address(*addr) {} Address(const Address& obj) : address(obj.address) {}; - Address(const Proto::Network::Address& addr) { this->deserialize(addr); }; - bool operator!=(const Address &obj) { return !(*this == obj); }; - bool operator==(const Address &obj); + bool operator!=(const Address &obj) const { return !(*this == obj); }; + bool operator==(const Address &obj) const; void setPort(unsigned short port); unsigned short getPort(); @@ -37,17 +36,14 @@ namespace Components void toSockAddr(sockaddr* addr); void toSockAddr(sockaddr_in* addr); Game::netadr_t* get(); - const char* getCString(); - std::string getString(); + const char* getCString() const; + std::string getString() const; bool isLocal(); bool isSelf(); bool isValid(); bool isLoopback(); - void serialize(Proto::Network::Address* protoAddress); - void deserialize(const Proto::Network::Address& protoAddress); - private: Game::netadr_t address; }; @@ -94,3 +90,12 @@ namespace Components static void PacketErrorCheck(); }; } + +template <> +struct std::hash +{ + std::size_t operator()(const Components::Network::Address& k) const + { + return (std::hash()(k.getString())); + } +}; diff --git a/src/Components/Modules/Node.cpp b/src/Components/Modules/Node.cpp index 7e0374d2..facae0e8 100644 --- a/src/Components/Modules/Node.cpp +++ b/src/Components/Modules/Node.cpp @@ -2,11 +2,50 @@ namespace Components { - std::recursive_mutex Node::NodeMutex; - std::mutex Node::SessionMutex; - Utils::Cryptography::ECC::Key Node::SignatureKey; - std::vector Node::Nodes; - std::vector Node::Sessions; + std::recursive_mutex Node::Mutex; + std::vector Node::Nodes; + + bool Node::Entry::isValid() + { + return (this->lastResponse.has_value() && !this->lastResponse->elapsed(NODE_HALFLIFE * 2)); + } + + bool Node::Entry::isDead() + { + if (!this->lastResponse.has_value()) + { + if (this->lastRequest.has_value() && this->lastRequest->elapsed(NODE_HALFLIFE) && this->creationPoint.elapsed(NODE_HALFLIFE)) + { + return true; + } + } + else if(this->lastResponse->elapsed(NODE_HALFLIFE * 2) && this->lastRequest.has_value() && this->lastRequest->after(*this->lastResponse)) + { + return true; + } + + return false; + } + + bool Node::Entry::requiresRequest() + { + return (!this->isDead() && (!this->lastRequest.has_value() || this->lastRequest->elapsed(NODE_HALFLIFE))); + } + + void Node::Entry::sendRequest() + { + if (!this->lastRequest.has_value()) this->lastRequest.emplace(Utils::Time::Point()); + this->lastRequest->update(); + + Session::Send(this->address, "nodeListRequest"); + Node::SendList(this->address); + } + + void Node::Entry::reset() + { + // this->lastResponse.reset(); // This would invalidate the node, but maybe we don't want that? + this->lastRequest.reset(); + } void Node::LoadNodeRemotePreset() { @@ -18,7 +57,7 @@ namespace Components { Utils::String::Replace(node, "\r", ""); node = Utils::String::Trim(node); - Node::AddNode(node, true); + Node::Add(node); } } @@ -37,9 +76,14 @@ namespace Components if (!defaultNodes.exists() || !list.ParseFromString(Utils::Compression::ZLib::Decompress(defaultNodes.getBuffer()))) return; } - for (int i = 0; i < list.address_size(); ++i) + for (int i = 0; i < list.nodes_size(); ++i) { - Node::AddNode(list.address(i), true); + const std::string& addr = list.nodes(i); + + if (addr.size() == sizeof(sockaddr)) + { + Node::Add(reinterpret_cast(const_cast(addr.data()))); + } } } @@ -49,352 +93,182 @@ namespace Components std::string nodes = Utils::IO::ReadFile("players/nodes.dat"); if (nodes.empty() || !list.ParseFromString(Utils::Compression::ZLib::Decompress(nodes))) return; - for (int i = 0; i < list.address_size(); ++i) + for (int i = 0; i < list.nodes_size(); ++i) { - Node::AddNode(list.address(i)); + const std::string& addr = list.nodes(i); + + if (addr.size() == sizeof(sockaddr)) + { + Node::Add(reinterpret_cast(const_cast(addr.data()))); + } } } + void Node::StoreNodes(bool force) { if (Dedicated::IsEnabled() && Dvar::Var("sv_lanOnly").get()) return; - static int lastStorage = 0; - - // Don't store nodes if the delta is too small and were not forcing it - if (((Game::Sys_Milliseconds() - lastStorage) < NODE_STORE_INTERVAL && !force) || !Node::GetValidNodeCount()) return; - lastStorage = Game::Sys_Milliseconds(); + static Utils::Time::Interval interval; + if (!force && !interval.elapsed(1min)) return; Proto::Node::List list; - // This is obsolete when storing to file. - // However, defining another proto message due to this would be redundant. - //list.set_is_dedi(Dedicated::IsDedicated()); - - std::lock_guard _(Node::NodeMutex); + std::lock_guard _(Node::Mutex); for (auto& node : Node::Nodes) { - if (node.state == Node::STATE_VALID && node.registered) + if (node.isValid()) { - node.address.serialize(list.add_address()); + std::string* str = list.add_nodes(); + + sockaddr addr = node.address.getSockAddr(); + str->append(reinterpret_cast(&addr), sizeof(addr)); } } Utils::IO::WriteFile("players/nodes.dat", Utils::Compression::ZLib::Compress(list.SerializeAsString())); } - Node::NodeEntry* Node::FindNode(Network::Address address) + + void Node::Add(Network::Address address) { - std::lock_guard _(Node::NodeMutex); - - for (auto i = Node::Nodes.begin(); i != Node::Nodes.end(); ++i) - { - if (i->address == address) - { - return &(*i); - } - } - - return nullptr; - } - Node::ClientSession* Node::FindSession(Network::Address address) - { - for (auto i = Node::Sessions.begin(); i != Node::Sessions.end(); ++i) - { - if (i->address == address) - { - return &(*i); - } - } - - return nullptr; - } - - unsigned int Node::GetValidNodeCount() - { - unsigned int count = 0; - std::lock_guard _(Node::NodeMutex); - - for (auto& node : Node::Nodes) - { - if (node.state == Node::STATE_VALID) - { - ++count; - } - } - - return count; - } - - void Node::AddNode(Network::Address address, bool def) - { -#ifdef DEBUG - if (!address.isValid() || address.isSelf()) return; -#else - if (!address.isValid() || address.isLocal() || address.isSelf()) return; +#ifndef DEBUG + if (address.isLocal() || address.isSelf()) return; #endif - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* existingEntry = Node::FindNode(address); - if (existingEntry) + std::lock_guard _(Node::Mutex); + for (auto& session : Node::Nodes) { - existingEntry->def = false; - existingEntry->lastHeard = Game::Sys_Milliseconds(); + if (session.address == address) return; } - else - { - int count = 0; - for (auto entry : Node::Nodes) - { - if (entry.state != Node::STATE_INVALID && entry.address.getIP().full == address.getIP().full) - { - count++; - } - if (count >= NODE_IP_LIMIT) return; - } + Node::Entry node; + node.address = address; - Node::NodeEntry entry; - - entry.lastHeard = Game::Sys_Milliseconds(); - entry.def = def; - entry.lastTime = 0; - entry.lastListQuery = 0; - entry.registered = false; - entry.state = Node::STATE_UNKNOWN; - entry.address = address; - entry.challenge.clear(); - - Node::Nodes.push_back(entry); - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Adding node %s...\n", address.getCString()); -#endif - } + Node::Nodes.push_back(node); } - void Node::SendNodeList(Network::Address address) - { - if (address.isSelf()) return; - - Proto::Node::List list; - list.set_is_dedi(Dedicated::IsEnabled()); - list.set_protocol(PROTOCOL); - list.set_version(NODE_VERSION); - - std::lock_guard _(Node::NodeMutex); - - for (auto& node : Node::Nodes) - { - if (node.state == Node::STATE_VALID && node.registered) - { - node.address.serialize(list.add_address()); - } - - if (list.address_size() >= NODE_PACKET_LIMIT) - { - Network::SendCommand(address, "nodeListResponse", list.SerializeAsString()); - list.clear_address(); - } - } - - // Even if we send an empty list, we have to tell the client about our dedi-status - // If the amount of servers we have modulo the NODE_PACKET_LIMIT equals 0, we will send this request without any servers, so it's obsolete, but meh... - Network::SendCommand(address, "nodeListResponse", list.SerializeAsString()); - } - - void Node::DeleteInvalidSessions() - { - std::lock_guard _(Node::SessionMutex); - for (auto i = Node::Sessions.begin(); i != Node::Sessions.end();) - { - if (i->lastTime <= 0 || (Game::Sys_Milliseconds() - i->lastTime) > SESSION_TIMEOUT) - { - i = Node::Sessions.erase(i); - } - else - { - ++i; - } - } - } - - void Node::DeleteInvalidNodes() - { - std::lock_guard _(Node::NodeMutex); - - for (auto i = Node::Nodes.begin(); i != Node::Nodes.end();) - { - if (i->state == Node::STATE_INVALID && ((Game::Sys_Milliseconds() - i->lastHeard) > NODE_INVALID_DELETE || i->def)) - { -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Removing invalid node %s\n", i->address.getCString()); -#endif - i = Node::Nodes.erase(i); - } - else - { - ++i; - } - } - } - - void Node::SyncNodeList() - { - std::lock_guard _(Node::NodeMutex); - for (auto& node : Node::Nodes) - { - if (node.state == Node::STATE_VALID && node.registered) - { - node.state = Node::STATE_UNKNOWN; - node.registered = false; - } - } - } - - void Node::PerformRegistration(Network::Address address) - { - Node::NodeEntry* entry = Node::FindNode(address); - if (!entry) return; - - entry->lastTime = Game::Sys_Milliseconds(); - - if (Dedicated::IsEnabled()) - { - entry->challenge = Utils::Cryptography::Rand::GenerateChallenge(); - - Proto::Node::Packet packet; - packet.set_port(Node::GetPort()); - packet.set_challenge(entry->challenge); - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Sending registration request to %s\n", entry->address.getCString()); -#endif - Network::SendCommand(entry->address, "nodeRegisterRequest", packet.SerializeAsString()); - } - else - { -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Sending session request to %s\n", entry->address.getCString()); -#endif - Network::SendCommand(entry->address, "sessionRequest"); - } - } - - void Node::FrameHandler() + void Node::RunFrame() { if (Dedicated::IsEnabled() && Dvar::Var("sv_lanOnly").get()) return; - // Frame limit - static int lastFrame = 0; - if ((Game::Sys_Milliseconds() - lastFrame) < (1000 / NODE_FRAME_LOCK)/* || Game::Sys_Milliseconds() < 5000*/) return; - lastFrame = Game::Sys_Milliseconds(); - - int registerCount = 0; - int listQueryCount = 0; + std::lock_guard _(Node::Mutex); + int sentRequests = 0; + for (auto i = Node::Nodes.begin(); i != Node::Nodes.end();) { - std::lock_guard _(Node::NodeMutex); - for (auto &node : Node::Nodes) + if (i->isDead()) { - // TODO: Decide how to handle nodes that were already registered, but timed out re-registering. - if (node.state == STATE_NEGOTIATING && (Game::Sys_Milliseconds() - node.lastTime) > (NODE_QUERY_TIMEOUT)) - { - node.registered = false; // Definitely unregister here! - node.state = Node::STATE_INVALID; - node.lastHeard = Game::Sys_Milliseconds(); - node.lastTime = Game::Sys_Milliseconds(); + i = Node::Nodes.erase(i); + continue; + } + else if (sentRequests < NODE_REQUEST_LIMIT && i->requiresRequest()) + { + ++sentRequests; + i->sendRequest(); + } -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Node negotiation timed out. Invalidating %s\n", node.address.getCString()); -#endif - } + ++i; + } + } - if (registerCount < NODE_FRAME_QUERY_LIMIT) - { - // Register when unregistered and in UNKNOWN state (I doubt it's possible to be unregistered and in VALID state) - if (!node.registered && (node.state != Node::STATE_NEGOTIATING && node.state != Node::STATE_INVALID)) - { - ++registerCount; - node.state = Node::STATE_NEGOTIATING; - Node::PerformRegistration(node.address); - } - // Requery invalid nodes within the NODE_QUERY_INTERVAL - // This is required, as a node might crash, which causes it to be invalid. - // If it's restarted though, we wouldn't query it again. + void Node::Synchronize() + { + std::lock_guard _(Node::Mutex); + for (auto& node : Node::Nodes) + { + //if (node.isValid()) // Comment out to simulate 'syncnodes' behaviour + { + node.reset(); + } + } + } - // But wouldn't it send a registration request to us? - // Not sure if the code below is necessary... - // Well, it might be possible that this node doesn't know use anymore. Anyways, just keep that code here... + void Node::HandleResponse(Network::Address address, std::string data) + { + Proto::Node::List list; + if (!list.ParseFromString(data)) return; - // Nvm, this is required for clients, as nodes don't send registration requests to clients. - else if (node.state == STATE_INVALID && (Game::Sys_Milliseconds() - node.lastTime) > NODE_QUERY_INTERVAL) - { - ++registerCount; - Node::PerformRegistration(node.address); - } - } + std::lock_guard _(Node::Mutex); - if (listQueryCount < NODE_FRAME_QUERY_LIMIT) - { - if (node.registered && node.state == Node::STATE_VALID && (!node.lastListQuery || (Game::Sys_Milliseconds() - node.lastListQuery) > NODE_QUERY_INTERVAL)) - { - ++listQueryCount; - node.state = Node::STATE_NEGOTIATING; - node.lastTime = Game::Sys_Milliseconds(); - node.lastListQuery = Game::Sys_Milliseconds(); + for (int i = 0; i < list.nodes_size(); ++i) + { + const std::string& addr = list.nodes(i); - if (Dedicated::IsEnabled()) - { - Network::SendCommand(node.address, "nodeListRequest"); - } - else - { - Network::SendCommand(node.address, "sessionRequest"); - } - } - } + if (addr.size() == sizeof(sockaddr)) + { + Node::Add(reinterpret_cast(const_cast(addr.data()))); } } - static int lastCheck = 0; - if ((Game::Sys_Milliseconds() - lastCheck) < 1000) return; - lastCheck = Game::Sys_Milliseconds(); - - Node::DeleteInvalidSessions(); - Node::DeleteInvalidNodes(); - Node::StoreNodes(false); - } - - const char* Node::GetStateName(EntryState state) - { - switch (state) + if (list.isnode() && (!list.port() || list.port() == address.getPort())) { - case Node::STATE_UNKNOWN: - return "Unknown"; + if (!Dedicated::IsEnabled() && ServerList::IsOnlineList() && list.protocol() == PROTOCOL) + { + ServerList::InsertRequest(address); + } - case Node::STATE_NEGOTIATING: - return "Negotiating"; + for (auto& node : Node::Nodes) + { + if (address == node.address) + { + if (!node.lastResponse.has_value()) node.lastResponse.emplace(Utils::Time::Point()); + node.lastResponse->update(); - case Node::STATE_INVALID: - return "Invalid"; + node.data.protocol = list.protocol(); + return; + } + } - case Node::STATE_VALID: - return "Valid"; + Node::Entry entry; + entry.address = address; + entry.data.protocol = list.protocol(); + entry.lastResponse.emplace(Utils::Time::Point()); + + Node::Nodes.push_back(entry); + } + } + void Node::SendList(Network::Address address) + { + Proto::Node::List list; + list.set_isnode(Dedicated::IsEnabled()); + list.set_protocol(PROTOCOL); + list.set_port(Node::GetPort()); + + std::lock_guard _(Node::Mutex); + + for (auto& node : Node::Nodes) + { + if (node.isValid()) + { + std::string* str = list.add_nodes(); + + sockaddr addr = node.address.getSockAddr(); + str->append(reinterpret_cast(&addr), sizeof(addr)); + } } - return ""; + Session::Send(address, "nodeListResponse", list.SerializeAsString()); + } + + void Node::HandleRequest(Network::Address address, std::string /*data*/) + { + Node::SendList(address); + } + + unsigned short Node::GetPort() + { + if (Dvar::Var("net_natFix").get()) return 0; + return Network::GetPort(); } Node::Node() { - Node::Nodes.clear(); - - // ZoneBuilder doesn't require node stuff if (ZoneBuilder::IsEnabled()) return; Dvar::Register("net_natFix", false, 0, "Fix node registration for certain firewalls/routers"); - // Generate our ECDSA key - Node::SignatureKey = Utils::Cryptography::ECC::GenerateKey(512); + Scheduler::OnFrame(Node::RunFrame); + Session::Handle("nodeListResponse", Node::HandleResponse); + Session::Handle("nodeListRequest", Node::HandleRequest); // Load stored nodes auto loadNodes = []() @@ -406,527 +280,6 @@ namespace Components if (Monitor::IsEnabled()) Network::OnStart(loadNodes); else Dvar::OnInit(loadNodes); - // Send deadline when shutting down - if (Dedicated::IsEnabled()) - { - Scheduler::OnShutdown([]() - { - if (Dvar::Var("sv_lanOnly").get()) return; - - std::string challenge = Utils::Cryptography::Rand::GenerateChallenge(); - - Proto::Node::Packet packet; - packet.set_port(Node::GetPort()); - packet.set_challenge(challenge); - packet.set_signature(Utils::Cryptography::ECC::SignMessage(Node::SignatureKey, challenge)); - - std::lock_guard _(Node::NodeMutex); - for (auto& node : Node::Nodes) - { - Network::SendCommand(node.address, "nodeDeregister", packet.SerializeAsString()); - } - }); - - // This is the handler that accepts registration requests from other nodes - // If you want to get accepted as node, you have to send a request to this handler - Network::Handle("nodeRegisterRequest", [](Network::Address address, std::string data) - { - if (Dvar::Var("sv_lanOnly").get()) return; - - Proto::Node::Packet packet; - if (!packet.ParseFromString(data)) return; - if (packet.challenge().empty()) return; - if (packet.port() && packet.port() != address.getPort()) return; - - // Create a new entry, if we don't already know it - if (!Node::FindNode(address)) - { - Node::AddNode(address); - if (!Node::FindNode(address)) return; - } - - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Received registration request from %s\n", address.getCString()); -#endif - - std::string signature = Utils::Cryptography::ECC::SignMessage(Node::SignatureKey, packet.challenge()); - std::string challenge = Utils::Cryptography::Rand::GenerateChallenge(); - - // The challenge this client sent is exactly the challenge we stored for this client - // That means this is us, so we're going to ignore us :P - if (packet.challenge() == entry->challenge) - { - entry->lastHeard = Game::Sys_Milliseconds(); - entry->lastTime = Game::Sys_Milliseconds(); - entry->registered = false; - entry->state = Node::STATE_INVALID; - return; - } - - packet.Clear(); - packet.set_challenge(challenge); - packet.set_signature(signature); - packet.set_publickey(Node::SignatureKey.getPublicKey()); - packet.set_port(Node::GetPort()); - - entry->lastTime = Game::Sys_Milliseconds(); - entry->challenge = challenge; - entry->state = Node::STATE_NEGOTIATING; - - Network::SendCommand(address, "nodeRegisterSynchronize", packet.SerializeAsString()); - }); - - Network::Handle("nodeRegisterSynchronize", [](Network::Address address, std::string data) - { - if (Dvar::Var("sv_lanOnly").get()) return; - - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (!entry || entry->state != Node::STATE_NEGOTIATING) return; - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Received synchronization data for registration from %s!\n", address.getCString()); -#endif - - Proto::Node::Packet packet; - if (!packet.ParseFromString(data)) return; - if (packet.challenge().empty()) return; - if (packet.publickey().empty()) return; - if (packet.signature().empty()) return; - if (packet.port() && packet.port() != address.getPort()) return; - - std::string challenge = packet.challenge(); - std::string publicKey = packet.publickey(); - std::string signature = packet.signature(); - - // Verify signature - entry->publicKey.set(publicKey); - if (!Utils::Cryptography::ECC::VerifyMessage(entry->publicKey, entry->challenge, signature)) - { - Logger::Print("Signature from %s for challenge '%s' is invalid!\n", address.getCString(), entry->challenge.data()); - return; - } - - for (auto& node : Node::Nodes) - { - if (node.publicKey == entry->publicKey) - { - entry->lastTime = Game::Sys_Milliseconds(); - entry->state = Node::STATE_INVALID; - } - } - - // Mark as registered - entry->lastTime = Game::Sys_Milliseconds(); - entry->state = Node::STATE_VALID; - entry->def = false; - entry->registered = true; - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Signature from %s for challenge '%s' is valid!\n", address.getCString(), entry->challenge.data()); - Logger::Print("Node %s registered\n", address.getCString()); -#endif - - // Build response - publicKey = Node::SignatureKey.getPublicKey(); - signature = Utils::Cryptography::ECC::SignMessage(Node::SignatureKey, challenge); - - packet.Clear(); - packet.set_signature(signature); - packet.set_publickey(publicKey); - - Network::SendCommand(address, "nodeRegisterAcknowledge", packet.SerializeAsString()); - }); - - Network::Handle("nodeRegisterAcknowledge", [](Network::Address address, std::string data) - { - if (Dvar::Var("sv_lanOnly").get()) return; - - // Ignore requests from nodes we don't know - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (!entry || entry->state != Node::STATE_NEGOTIATING) return; - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Received acknowledgment from %s\n", address.getCString()); -#endif - - Proto::Node::Packet packet; - if (!packet.ParseFromString(data)) return; - if (packet.signature().empty()) return; - if (packet.publickey().empty()) return; - if (packet.port() && packet.port() != address.getPort()) return; - - std::string publicKey = packet.publickey(); - std::string signature = packet.signature(); - - entry->publicKey.set(publicKey); - - if (Utils::Cryptography::ECC::VerifyMessage(entry->publicKey, entry->challenge, signature)) - { - entry->lastTime = Game::Sys_Milliseconds(); - entry->state = Node::STATE_VALID; - entry->def = false; - entry->registered = true; - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Signature from %s for challenge '%s' is valid!\n", address.getCString(), entry->challenge.data()); - Logger::Print("Node %s registered\n", address.getCString()); -#endif - } - else - { -#ifdef DEBUG - Logger::Print("Signature from %s for challenge '%s' is invalid!\n", address.getCString(), entry->challenge.data()); -#endif - } - }); - - Network::Handle("nodeListRequest", [](Network::Address address, std::string data) - { - if (Dvar::Var("sv_lanOnly").get()) return; - - // Check if this is a registered node - bool allowed = false; - - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (entry && entry->registered) - { - entry->lastTime = Game::Sys_Milliseconds(); - allowed = true; - } - - // Check if there is any open session - if (!allowed) - { - std::lock_guard __(Node::SessionMutex); - Node::ClientSession* session = Node::FindSession(address); - if (session) - { - session->lastTime = Game::Sys_Milliseconds(); - allowed = session->valid; - } - } - - if (allowed) - { - Node::SendNodeList(address); - } - else - { -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - // Unallowed connection - Logger::Print("Node list requested by %s, but no valid session was present!\n", address.getCString()); -#endif - Network::SendCommand(address, "nodeListError"); - } - }); - - Network::Handle("nodeDeregister", [](Network::Address address, std::string data) - { - if (Dvar::Var("sv_lanOnly").get()) return; - - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (!entry || !entry->registered) return; - - Proto::Node::Packet packet; - if (!packet.ParseFromString(data)) return; - if (packet.challenge().empty()) return; - if (packet.signature().empty()) return; - if (packet.port() && packet.port() != address.getPort()) return; - - std::string challenge = packet.challenge(); - std::string signature = packet.signature(); - - if (Utils::Cryptography::ECC::VerifyMessage(entry->publicKey, challenge, signature)) - { - entry->def = true; - entry->lastHeard = Game::Sys_Milliseconds(); - entry->lastTime = Game::Sys_Milliseconds(); - entry->registered = false; - entry->state = Node::STATE_INVALID; - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Node %s unregistered\n", address.getCString()); -#endif - } - else - { -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Node %s tried to unregister using an invalid signature!\n", address.getCString()); -#endif - } - }); - - Network::Handle("sessionRequest", [](Network::Address address, std::string data) - { - if (Dvar::Var("sv_lanOnly").get()) return; - - // Search an active session, if we haven't found one, register a template - std::lock_guard _(Node::SessionMutex); - if (!Node::FindSession(address)) - { - Node::ClientSession templateSession; - templateSession.address = address; - Node::Sessions.push_back(templateSession); - } - - // Search our target session (this should not fail!) - Node::ClientSession* session = Node::FindSession(address); - if (!session) return; // Registering template session failed, odd... - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Client %s is requesting a new session\n", address.getCString()); -#endif - - // Initialize session data - session->challenge = Utils::Cryptography::Rand::GenerateChallenge(); - session->lastTime = Game::Sys_Milliseconds(); - session->valid = false; - - Network::SendCommand(address, "sessionInitialize", session->challenge); - }); - - Network::Handle("sessionSynchronize", [](Network::Address address, std::string data) - { - if (Dvar::Var("sv_lanOnly").get()) return; - - // Return if we don't have a session for this address - std::lock_guard _(Node::SessionMutex); - Node::ClientSession* session = Node::FindSession(address); - if (!session || session->valid) return; - - if (session->challenge == data) - { -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Session for %s validated.\n", address.getCString()); -#endif - session->valid = true; - Network::SendCommand(address, "sessionAcknowledge"); - } - else - { - session->lastTime = -1; -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Challenge mismatch. Validating session for %s failed.\n", address.getCString()); -#endif - } - }); - } - else - { - Network::Handle("sessionInitialize", [](Network::Address address, std::string data) - { - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (!entry) return; - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Session initialization received from %s. Synchronizing...\n", address.getCString()); -#endif - - entry->lastTime = Game::Sys_Milliseconds(); - Network::SendCommand(address, "sessionSynchronize", data); - }); - - Network::Handle("sessionAcknowledge", [](Network::Address address, std::string data) - { - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (!entry) return; - - entry->state = Node::STATE_VALID; - entry->def = false; - entry->registered = true; - entry->lastTime = Game::Sys_Milliseconds(); - -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Session acknowledged by %s, synchronizing node list...\n", address.getCString()); -#endif - Network::SendCommand(address, "nodeListRequest"); - Node::SendNodeList(address); - }); - } - - Network::Handle("nodeListResponse", [](Network::Address address, std::string data) - { - Proto::Node::List list; - std::lock_guard _(Node::NodeMutex); - - if (data.empty() || !list.ParseFromString(data)) - { -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Received invalid node list from %s!\n", address.getCString()); -#endif - return; - } - - Node::NodeEntry* entry = Node::FindNode(address); - if (entry) - { - if (entry->registered) - { -#if defined(DEBUG) && !defined(DISABLE_NODE_LOG) - Logger::Print("Received valid node list with %i entries from %s\n", list.address_size(), address.getCString()); -#endif - - entry->isDedi = list.is_dedi(); - entry->protocol = list.protocol(); - entry->version = list.version(); - entry->state = Node::STATE_VALID; - entry->def = false; - entry->lastTime = Game::Sys_Milliseconds(); - -#ifndef DEBUG - // Block old versions - if (entry->version < NODE_VERSION) - { - entry->state = Node::STATE_INVALID; - return; - } -#endif - - if (!Dedicated::IsEnabled() && entry->isDedi && ServerList::IsOnlineList() && entry->protocol == PROTOCOL) - { - ServerList::InsertRequest(entry->address); - } - - for (int i = 0; i < list.address_size(); ++i) - { - Network::Address _addr(list.address(i)); - - // Version 0 sends port in the wrong byte order! - if (list.version() == 0) - { - _addr.setPort(ntohs(_addr.getPort())); - } - -// if (!Node::FindNode(_addr) && _addr.GetPort() >= 1024 && _addr.GetPort() - 20 < 1024) -// { -// std::string a1 = _addr.getString(); -// std::string a2 = address.getString(); -// Logger::Print("Received weird address %s from %s\n", a1.data(), a2.data()); -// } - - Node::AddNode(_addr); - } - } - } - else - { - //Node::AddNode(address); - Node::ClientSession* session = Node::FindSession(address); - if (session && session->valid) - { - session->lastTime = Game::Sys_Milliseconds(); - - for (int i = 0; i < list.address_size(); ++i) - { - Node::AddNode(list.address(i)); - } - } - } - }); - - // If we receive that response, our request was not permitted - // So we either have to register as node, or register a remote session - Network::Handle("nodeListError", [](Network::Address address, std::string data) - { - if (Dedicated::IsEnabled()) - { - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (entry) - { - // Set to unregistered to perform registration later on - entry->lastTime = Game::Sys_Milliseconds(); - entry->registered = false; - entry->state = Node::STATE_UNKNOWN; - } - else - { - // Add as new entry to perform registration - Node::AddNode(address); - } - } - }); - - Command::Add("listnodes", [](Command::Params*) - { - Logger::Print("Nodes: %d (%d)\n", Node::Nodes.size(), Node::GetValidNodeCount()); - - std::lock_guard _(Node::NodeMutex); - for (auto& node : Node::Nodes) - { - Logger::Print("%s\t(%s)\n", node.address.getCString(), Node::GetStateName(node.state)); - } - }); - - Command::Add("nodeinfo", [](Command::Params*) - { - std::lock_guard _(Node::NodeMutex); - unsigned int valid = 0, invalid = 0, negotiating = 0, total = Node::Nodes.size(); - - for (auto& node : Node::Nodes) - { - valid += static_cast(node.state == Node::STATE_VALID); - invalid += static_cast(node.state == Node::STATE_INVALID); - negotiating += static_cast(node.state == Node::STATE_NEGOTIATING); - } - - Logger::Print("Total: %d\nValid: %d\nNegotiating: %d\nInvalid: %d\n\n", total, valid, negotiating, invalid); - }); - - Command::Add("addnode", [](Command::Params* params) - { - if (params->length() < 2) return; - - Network::Address address(params->get(1)); - Node::AddNode(address); - - std::lock_guard _(Node::NodeMutex); - Node::NodeEntry* entry = Node::FindNode(address); - if (entry) - { - entry->state = Node::STATE_UNKNOWN; - entry->registered = false; - } - }); - - Command::Add("syncnodes", [](Command::Params*) - { - Logger::Print("Resynchronizing nodes...\n"); - - static bool threadRunning = false; - - if (!threadRunning) - { - threadRunning = true; - std::thread([]() - { - Node::LoadNodeRemotePreset(); - threadRunning = false; - }).detach(); - } - - std::lock_guard _(Node::NodeMutex); - for (auto& node : Node::Nodes) - { - node.state = Node::STATE_UNKNOWN; - node.registered = false; - node.lastTime = 0; - node.lastListQuery = 0; - } - }); - - // Install frame handlers - Scheduler::OnFrame(Node::FrameHandler); - Network::OnStart([]() { std::thread([]() @@ -935,117 +288,29 @@ namespace Components }).detach(); }); - if (Dedicated::IsEnabled()) + Command::Add("listnodes", [](Command::Params*) { - Network::Handle("getServersRequest", [](Network::Address target, std::string) + Logger::Print("Nodes: %d\n", Node::Nodes.size()); + + std::lock_guard _(Node::Mutex); + for (auto& node : Node::Nodes) { - std::string data; + Logger::Print("%s\t(%s)\n", node.address.getCString(), node.isValid() ? "Valid" : "Invalid"); + } + }); - { - std::lock_guard _(Node::NodeMutex); - for (auto& node : Node::Nodes) - { - if (node.state == Node::STATE_VALID && node.isDedi) - { - Game::netIP_t ip = node.address.getIP(); - unsigned short port = htons(node.address.getPort()); - data.append(reinterpret_cast(&ip.full), 4); - data.append(reinterpret_cast(&port), 2); - data.append("\\"); - } - } + Command::Add("addnode", [](Command::Params* params) + { + if (params->length() < 2) return; + Node::Add({ params->get(1) }); + }); - data.append("EOT"); - } - - Network::SendCommand(target, "getServersResponse", data); - }); - } } Node::~Node() { - Node::SignatureKey.free(); + std::lock_guard _(Node::Mutex); Node::StoreNodes(true); - - { - std::lock_guard _(Node::NodeMutex); - std::lock_guard __(Node::SessionMutex); - Node::Nodes.clear(); - Node::Sessions.clear(); - } - } - - unsigned short Node::GetPort() - { - if (Dvar::Var("net_natFix").get()) return 0; - return Network::GetPort(); - } - - bool Node::unitTest() - { - printf("Testing ECDSA key..."); - - if (!Node::SignatureKey.isValid()) - { - printf("Error\n"); - printf("ECDSA key seems invalid!\n"); - return false; - } - - printf("Success\n"); - printf("Testing 10 valid signatures..."); - - for (int i = 0; i < 10; ++i) - { - std::string message = Utils::Cryptography::Rand::GenerateChallenge(); - std::string signature = Utils::Cryptography::ECC::SignMessage(Node::SignatureKey, message); - - if (!Utils::Cryptography::ECC::VerifyMessage(Node::SignatureKey, message, signature)) - { - printf("Error\n"); - printf("Signature for '%s' (%d) was invalid!\n", message.data(), i); - return false; - } - } - - printf("Success\n"); - printf("Testing 10 invalid signatures..."); - - for (int i = 0; i < 10; ++i) - { - std::string message = Utils::Cryptography::Rand::GenerateChallenge(); - std::string signature = Utils::Cryptography::ECC::SignMessage(Node::SignatureKey, message); - - // Invalidate the message... - ++message[Utils::Cryptography::Rand::GenerateInt() % message.size()]; - - if (Utils::Cryptography::ECC::VerifyMessage(Node::SignatureKey, message, signature)) - { - printf("Error\n"); - printf("Signature for '%s' (%d) was valid? What the fuck? That is absolutely impossible...\n", message.data(), i); - return false; - } - } - - printf("Success\n"); - printf("Testing ECDSA key import..."); - - std::string pubKey = Node::SignatureKey.getPublicKey(); - std::string message = Utils::Cryptography::Rand::GenerateChallenge(); - std::string signature = Utils::Cryptography::ECC::SignMessage(Node::SignatureKey, message); - - Utils::Cryptography::ECC::Key testKey; - testKey.set(pubKey); - - if (!Utils::Cryptography::ECC::VerifyMessage(Node::SignatureKey, message, signature)) - { - printf("Error\n"); - printf("Verifying signature for message '%s' using imported keys failed!\n", message.data()); - return false; - } - - printf("Success\n"); - return true; + Node::Nodes.clear(); } } diff --git a/src/Components/Modules/Node.hpp b/src/Components/Modules/Node.hpp index e4a01f19..44502488 100644 --- a/src/Components/Modules/Node.hpp +++ b/src/Components/Modules/Node.hpp @@ -1,99 +1,60 @@ #pragma once -#define NODE_QUERY_INTERVAL 1000 * 60 * 2 // Query nodelist from nodes evry 2 minutes -#define NODE_QUERY_TIMEOUT 1000 * 30 * 1 // Invalidate nodes after 30 seconds without query response -#define NODE_INVALID_DELETE 1000 * 60 * 10 // Delete invalidated nodes after 10 minutes -#define NODE_FRAME_QUERY_LIMIT 3 // Limit of nodes to be queried per frame -#define NODE_FRAME_LOCK 60 // Limit of max frames per second -#define NODE_PACKET_LIMIT 111 // Send 111 nodes per synchronization packet -#define NODE_STORE_INTERVAL 1000 * 60* 1 // Store nodes every minute -#define SESSION_TIMEOUT 1000 * 10 // 10 seconds session timeout - -#define NODE_IP_LIMIT 15 - -#define NODE_VERSION 5 +#define NODE_HALFLIFE 3min +#define NODE_REQUEST_LIMIT 3 namespace Components { class Node : public Component { public: + class Data + { + public: + uint64_t protocol; + }; + + class Entry + { + public: + Network::Address address; + Data data; + + std::optional lastRequest; + std::optional lastResponse; + Utils::Time::Point creationPoint; + + bool isValid(); + bool isDead(); + + bool requiresRequest(); + void sendRequest(); + + void reset(); + }; + Node(); ~Node(); - bool unitTest() override; - - static void SyncNodeList(); - static void AddNode(Network::Address address, bool def = false); - - static unsigned int GetValidNodeCount(); + static void Add(Network::Address address); + static void RunFrame(); + static void Synchronize(); static void LoadNodeRemotePreset(); - static void FrameHandler(); - private: - enum EntryState - { - STATE_UNKNOWN, - STATE_NEGOTIATING, - STATE_VALID, - STATE_INVALID, - }; + static std::recursive_mutex Mutex; + static std::vector Nodes; - class NodeEntry - { - public: - Network::Address address; - std::string challenge; - Utils::Cryptography::ECC::Key publicKey; - EntryState state; + static void HandleResponse(Network::Address address, std::string data); + static void HandleRequest(Network::Address address, std::string data); - bool registered; // Do we consider this node as registered? + static void SendList(Network::Address address); - int lastTime; // Last time we heard anything from the server itself - int lastHeard; // Last time we heard something of the server at all (refs form other nodes) - int lastListQuery; // Last time we got the list of the node - - bool def; - - // This is only relevant for clients - bool isDedi; - uint32_t protocol; - uint32_t version; - }; - - class ClientSession - { - public: - Network::Address address; - std::string challenge; - bool valid; - //bool terminated; // Sessions can't explicitly be terminated, they can only timeout - int lastTime; - }; - - static Utils::Cryptography::ECC::Key SignatureKey; - - static std::recursive_mutex NodeMutex; - static std::mutex SessionMutex; - static std::vector Nodes; - static std::vector Sessions; - - static void LoadNodes(); static void LoadNodePreset(); + static void LoadNodes(); static void StoreNodes(bool force); - static void PerformRegistration(Network::Address address); - static void SendNodeList(Network::Address address); - static NodeEntry* FindNode(Network::Address address); - static ClientSession* FindSession(Network::Address address); - - static void DeleteInvalidNodes(); - static void DeleteInvalidSessions(); - static unsigned short GetPort(); - - static const char* GetStateName(EntryState state); }; } diff --git a/src/Components/Modules/Party.cpp b/src/Components/Modules/Party.cpp index 2812df25..2059f435 100644 --- a/src/Components/Modules/Party.cpp +++ b/src/Components/Modules/Party.cpp @@ -24,7 +24,7 @@ namespace Components void Party::Connect(Network::Address target) { - Node::AddNode(target); + Node::Add(target); Party::Container.valid = true; Party::Container.awaitingPlaylist = false; diff --git a/src/Components/Modules/ServerList.cpp b/src/Components/Modules/ServerList.cpp index 3df95f4e..235a94c5 100644 --- a/src/Components/Modules/ServerList.cpp +++ b/src/Components/Modules/ServerList.cpp @@ -283,7 +283,7 @@ namespace Components Network::SendCommand(ServerList::RefreshContainer.host, "getservers", Utils::String::VA("IW4 %i full empty", PROTOCOL)); //Network::SendCommand(ServerList::RefreshContainer.Host, "getservers", "0 full empty"); #else - Node::SyncNodeList(); + Node::Synchronize(); #endif } else if (ServerList::IsFavouriteList()) diff --git a/src/Components/Modules/Session.cpp b/src/Components/Modules/Session.cpp new file mode 100644 index 00000000..721a741b --- /dev/null +++ b/src/Components/Modules/Session.cpp @@ -0,0 +1,211 @@ +#include "STDInclude.hpp" + +namespace Components +{ + std::recursive_mutex Session::Mutex; + std::unordered_map Session::Sessions; + std::unordered_map>> Session::PacketQueue; + + Utils::Cryptography::ECC::Key Session::SignatureKey; + + std::map> Session::PacketHandlers; + + void Session::Send(Network::Address target, std::string command, std::string data) + { + std::lock_guard _(Session::Mutex); + + auto queue = Session::PacketQueue.find(target); + if (queue == Session::PacketQueue.end()) + { + Session::PacketQueue[target] = std::queue>(); + queue = Session::PacketQueue.find(target); + if (queue == Session::PacketQueue.end()) Logger::Error("Failed to enqueue session packet!\n"); + } + + std::shared_ptr packet = std::make_shared(); + packet->command = command; + packet->data = data; + packet->tries = 0; + + queue->second.push(packet); + } + + void Session::Handle(std::string packet, Utils::Slot callback) + { + std::lock_guard _(Session::Mutex); + Session::PacketHandlers[packet] = callback; + } + + void Session::RunFrame() + { + std::lock_guard _(Session::Mutex); + + for (auto queue = Session::PacketQueue.begin(); queue != Session::PacketQueue.end();) + { + if (queue->second.empty()) + { + queue = Session::PacketQueue.erase(queue); + continue; + } + + std::shared_ptr packet = queue->second.front(); + if (!packet->lastTry.has_value() || !packet->tries || (packet->lastTry.has_value() && packet->lastTry->elapsed(SESSION_TIMEOUT))) + { + if (packet->tries <= SESSION_MAX_RETRIES) + { + packet->tries++; + packet->lastTry.emplace(Utils::Time::Point()); + + Network::SendCommand(queue->first, "sessionSyn"); + } + else + { + queue->second.pop(); + } + } + + ++queue; + } + } + + Session::Session() + { + Session::SignatureKey = Utils::Cryptography::ECC::GenerateKey(512); + + Scheduler::OnFrame(Session::RunFrame); + + Network::Handle("sessionSyn", [](Network::Address address, std::string data) + { + std::lock_guard _(Session::Mutex); + + Session::Frame frame; + frame.challenge = Utils::Cryptography::Rand::GenerateChallenge(); + Session::Sessions[address] = frame; + + Network::SendCommand(address, "sessionAck", frame.challenge); + }); + + Network::Handle("sessionAck", [](Network::Address address, std::string data) + { + std::lock_guard _(Session::Mutex); + + auto queue = Session::PacketQueue.find(address); + if (queue == Session::PacketQueue.end()) return; + + if (!queue->second.empty()) + { + std::shared_ptr packet = queue->second.front(); + queue->second.pop(); + + Proto::Session::Packet dataPacket; + dataPacket.set_publickey(Session::SignatureKey.getPublicKey()); + dataPacket.set_signature(Utils::Cryptography::ECC::SignMessage(Session::SignatureKey, data)); + dataPacket.set_command(packet->command); + dataPacket.set_data(packet->data); + + Network::SendCommand(address, "sessionFin", dataPacket.SerializeAsString()); + } + }); + + Network::Handle("sessionFin", [](Network::Address address, std::string data) + { + std::lock_guard _(Session::Mutex); + + auto frame = Session::Sessions.find(address); + if (frame == Session::Sessions.end()) return; + + std::string challenge = frame->second.challenge; + Session::Sessions.erase(frame); + + Proto::Session::Packet dataPacket; + if (!dataPacket.ParseFromString(data)) return; + + Utils::Cryptography::ECC::Key publicKey; + publicKey.set(dataPacket.publickey()); + + if (!Utils::Cryptography::ECC::VerifyMessage(publicKey, challenge, dataPacket.signature())) return; + + auto handler = Session::PacketHandlers.find(dataPacket.command()); + if (handler == Session::PacketHandlers.end()) return; + + handler->second(address, dataPacket.data()); + }); + } + + Session::~Session() + { + std::lock_guard _(Session::Mutex); + Session::PacketHandlers.clear(); + Session::PacketQueue.clear(); + + Session::SignatureKey.free(); + } + + bool Session::unitTest() + { + printf("Testing ECDSA key..."); + Utils::Cryptography::ECC::Key key = Utils::Cryptography::ECC::GenerateKey(512); + + if (!key.isValid()) + { + printf("Error\n"); + printf("ECDSA key seems invalid!\n"); + return false; + } + + printf("Success\n"); + printf("Testing 10 valid signatures..."); + + for (int i = 0; i < 10; ++i) + { + std::string message = Utils::Cryptography::Rand::GenerateChallenge(); + std::string signature = Utils::Cryptography::ECC::SignMessage(key, message); + + if (!Utils::Cryptography::ECC::VerifyMessage(key, message, signature)) + { + printf("Error\n"); + printf("Signature for '%s' (%d) was invalid!\n", message.data(), i); + return false; + } + } + + printf("Success\n"); + printf("Testing 10 invalid signatures..."); + + for (int i = 0; i < 10; ++i) + { + std::string message = Utils::Cryptography::Rand::GenerateChallenge(); + std::string signature = Utils::Cryptography::ECC::SignMessage(key, message); + + // Invalidate the message... + ++message[Utils::Cryptography::Rand::GenerateInt() % message.size()]; + + if (Utils::Cryptography::ECC::VerifyMessage(key, message, signature)) + { + printf("Error\n"); + printf("Signature for '%s' (%d) was valid? What the fuck? That is absolutely impossible...\n", message.data(), i); + return false; + } + } + + printf("Success\n"); + printf("Testing ECDSA key import..."); + + std::string pubKey = key.getPublicKey(); + std::string message = Utils::Cryptography::Rand::GenerateChallenge(); + std::string signature = Utils::Cryptography::ECC::SignMessage(key, message); + + Utils::Cryptography::ECC::Key testKey; + testKey.set(pubKey); + + if (!Utils::Cryptography::ECC::VerifyMessage(key, message, signature)) + { + printf("Error\n"); + printf("Verifying signature for message '%s' using imported keys failed!\n", message.data()); + return false; + } + + printf("Success\n"); + return true; + } +} diff --git a/src/Components/Modules/Session.hpp b/src/Components/Modules/Session.hpp new file mode 100644 index 00000000..c5579716 --- /dev/null +++ b/src/Components/Modules/Session.hpp @@ -0,0 +1,48 @@ +#pragma once + +#define SESSION_TIMEOUT 10s +#define SESSION_MAX_RETRIES 3 +#define SESSION_REQUEST_LIMIT 3 + +namespace Components +{ + class Session : public Component + { + public: + class Packet + { + public: + std::string command; + std::string data; + + unsigned int tries; + std::optional lastTry; + }; + + class Frame + { + public: + std::string challenge; + Utils::Time::Point creationPoint; + }; + + Session(); + ~Session(); + + bool unitTest() override; + + static void Send(Network::Address target, std::string command, std::string data = ""); + static void Handle(std::string packet, Utils::Slot callback); + + static void RunFrame(); + + private: + static std::recursive_mutex Mutex; + static std::unordered_map Sessions; + static std::unordered_map>> PacketQueue; + + static Utils::Cryptography::ECC::Key SignatureKey; + + static std::map> PacketHandlers; + }; +} diff --git a/src/Proto/network.proto b/src/Proto/network.proto deleted file mode 100644 index 22f66543..00000000 --- a/src/Proto/network.proto +++ /dev/null @@ -1,10 +0,0 @@ -syntax = "proto3"; - -package Proto.Network; - -// TODO: Add support for IPv6, once the game supports it (I assume we'll implement it :P) -message Address -{ - uint32 ip = 1; - uint32 port = 2; // Actually only 16 bits, but apparently protobuf handles that (https://groups.google.com/d/msg/protobuf/Er39mNGnRWU/x6Srz_GrZPgJ) -} diff --git a/src/Proto/node.proto b/src/Proto/node.proto index ba1168a0..c26b8165 100644 --- a/src/Proto/node.proto +++ b/src/Proto/node.proto @@ -1,23 +1,16 @@ syntax = "proto3"; package Proto.Node; -import "network.proto"; - -message Packet -{ - bytes challenge = 1; - bytes signature = 2; - bytes publickey = 3; - - // The port is used to check if a dedi sends data through a redirected port. - // This usually means the port is not forwarded - uint32 port = 4; -} message List { - bool is_dedi = 1; - repeated Network.Address address = 2; - uint32 protocol = 3; - uint32 version = 4; + repeated bytes nodes = 1; + + // The port is used to check if a dedi sends data through a redirected port. + // This usually means the port is not forwarded + uint32 port = 2; + + // Additional data + bool isNode = 3; + uint64 protocol = 4; } diff --git a/src/Proto/session.proto b/src/Proto/session.proto new file mode 100644 index 00000000..8b786969 --- /dev/null +++ b/src/Proto/session.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package Proto.Session; + +message Packet +{ + bytes signature = 1; + bytes publicKey = 2; + bytes command = 3; + bytes data = 4; +} \ No newline at end of file diff --git a/src/STDInclude.hpp b/src/STDInclude.hpp index 32e58ae1..a0c9b0cf 100644 --- a/src/STDInclude.hpp +++ b/src/STDInclude.hpp @@ -89,7 +89,7 @@ template class Sizer { }; #endif // Protobuf -#include "proto/network.pb.h" +#include "proto/session.pb.h" #include "proto/party.pb.h" #include "proto/auth.pb.h" #include "proto/node.pb.h" diff --git a/src/Utils/Compression.cpp b/src/Utils/Compression.cpp index 8fe48e9c..c35837b6 100644 --- a/src/Utils/Compression.cpp +++ b/src/Utils/Compression.cpp @@ -8,6 +8,7 @@ namespace Utils { Utils::Memory::Allocator allocator; unsigned long length = (data.size() * 2); + if (!length) length = 2; // Make sure the buffer is large enough if (length < 100) length *= 10; @@ -16,7 +17,6 @@ namespace Utils if (compress2(reinterpret_cast(buffer), &length, reinterpret_cast(const_cast(data.data())), data.size(), Z_BEST_COMPRESSION) != Z_OK) { - Utils::Memory::Free(buffer); return ""; } diff --git a/src/Utils/Time.cpp b/src/Utils/Time.cpp index 5c62d45d..5dd115c5 100644 --- a/src/Utils/Time.cpp +++ b/src/Utils/Time.cpp @@ -13,5 +13,15 @@ namespace Utils { return ((std::chrono::high_resolution_clock::now() - this->lastPoint) >= nsecs); } + + std::chrono::high_resolution_clock::duration Point::diff(Point point) + { + return point.lastPoint - this->lastPoint; + } + + bool Point::after(Point point) + { + return this->diff(point).count() < 0; + } } } diff --git a/src/Utils/Time.hpp b/src/Utils/Time.hpp index 19da7e72..c2da4f6b 100644 --- a/src/Utils/Time.hpp +++ b/src/Utils/Time.hpp @@ -6,7 +6,7 @@ namespace Utils { class Interval { - private: + protected: std::chrono::high_resolution_clock::time_point lastPoint; public: @@ -15,5 +15,14 @@ namespace Utils void update(); bool elapsed(std::chrono::nanoseconds nsecs); }; + + class Point : public Interval + { + public: + Point() : Interval() {} + + std::chrono::high_resolution_clock::duration diff(Point point); + bool after(Point point); + }; } }