From 87c239a1dc7beb2a93f5256a283a38b0548c9e7b Mon Sep 17 00:00:00 2001 From: momo5502 Date: Mon, 8 Feb 2016 18:43:31 +0100 Subject: [PATCH] Fix cryptography memory leaks :P --- src/Components/Modules/Node.cpp | 7 +++ src/Components/Modules/Node.hpp | 3 ++ src/Utils/Cryptography.cpp | 5 +- src/Utils/Cryptography.hpp | 86 +++++++++++++++++++++++++++------ src/Utils/Utils.cpp | 16 ++++++ src/Utils/Utils.hpp | 2 + 6 files changed, 100 insertions(+), 19 deletions(-) diff --git a/src/Components/Modules/Node.cpp b/src/Components/Modules/Node.cpp index 4a01f02e..b522cd10 100644 --- a/src/Components/Modules/Node.cpp +++ b/src/Components/Modules/Node.cpp @@ -2,6 +2,8 @@ namespace Components { + Utils::Cryptography::ECDSA::Key Node::SignatureKey; + std::vector Node::Nodes; std::vector Node::Dedis; @@ -396,6 +398,9 @@ namespace Components // ZoneBuilder doesn't require node stuff if (ZoneBuilder::IsEnabled()) return; + // Generate our ECDSA key + Node::SignatureKey = Utils::Cryptography::ECDSA::GenerateKey(512); + Dvar::OnInit([] () { Node::Dedis.clear(); @@ -570,6 +575,8 @@ namespace Components Node::~Node() { + Node::SignatureKey.Free(); + Node::StoreNodes(true); Node::Nodes.clear(); Node::Dedis.clear(); diff --git a/src/Components/Modules/Node.hpp b/src/Components/Modules/Node.hpp index 07ca433d..e0a34231 100644 --- a/src/Components/Modules/Node.hpp +++ b/src/Components/Modules/Node.hpp @@ -44,6 +44,7 @@ namespace Components struct NodeEntry { Network::Address address; + Utils::Cryptography::ECDSA::Key publicKey; EntryState state; int lastTime; // Last time we heard anything from the server itself int lastHeartbeat; // Last time we got a heartbeat from it @@ -84,6 +85,8 @@ namespace Components }; #pragma pack(pop) + static Utils::Cryptography::ECDSA::Key SignatureKey; + static std::vector Nodes; static std::vector Dedis; diff --git a/src/Utils/Cryptography.cpp b/src/Utils/Cryptography.cpp index 868d06ff..66dddc24 100644 --- a/src/Utils/Cryptography.cpp +++ b/src/Utils/Cryptography.cpp @@ -30,7 +30,6 @@ namespace Utils ECDSA::Key key; register_prng(&sprng_desc); - register_hash(&sha1_desc); ltc_mp = ltm_desc; @@ -41,7 +40,7 @@ namespace Utils std::string ECDSA::SignMessage(Key key, std::string message) { - uint8_t buffer[0x200]; // Default size is 512 + uint8_t buffer[512]; DWORD length = sizeof(buffer); register_prng(&sprng_desc); @@ -81,7 +80,7 @@ namespace Utils std::string RSA::SignMessage(RSA::Key key, std::string message) { - uint8_t buffer[0x200]; // Default size is 512 + uint8_t buffer[512]; DWORD length = sizeof(buffer); register_prng(&sprng_desc); diff --git a/src/Utils/Cryptography.hpp b/src/Utils/Cryptography.hpp index bd0ea53a..ae36fad8 100644 --- a/src/Utils/Cryptography.hpp +++ b/src/Utils/Cryptography.hpp @@ -14,21 +14,33 @@ namespace Utils class Key { public: - Key() { ZeroMemory(&this->KeyStorage, sizeof(this->KeyStorage)); }; - Key(ecc_key* key) : Key(*key) {}; - Key(ecc_key key) : KeyStorage(key) {}; - Key(const Key& obj) : KeyStorage(obj.KeyStorage) {}; + Key() : KeyStorage(new ecc_key) + { + ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get())); + }; + Key(ecc_key* key) : Key() { std::memmove(this->KeyStorage.get(), key, sizeof(*key)); }; + Key(ecc_key key) : Key(&key) {}; + ~Key() + { + if (this->KeyStorage.use_count() <= 1) + { + this->Free(); + } + }; - ~Key() {} + bool IsValid() + { + return (!Utils::MemIsSet(this->KeyStorage.get(), 0, sizeof(*this->KeyStorage.get()))); + } ecc_key* GetKeyPtr() { - return &this->KeyStorage; + return this->KeyStorage.get(); } std::string GetPublicKey() { - uint8_t buffer[0x1000] = { 0 }; + uint8_t buffer[512] = { 0 }; DWORD length = sizeof(buffer); if (ecc_ansi_x963_export(this->GetKeyPtr(), buffer, &length) == CRYPT_OK) @@ -39,8 +51,28 @@ namespace Utils return ""; } + void Set(std::string pubKeyBuffer) + { + this->Free(); + + if (ecc_ansi_x963_import(reinterpret_cast(pubKeyBuffer.data()), pubKeyBuffer.size(), this->KeyStorage.get()) != CRYPT_OK) + { + ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get())); + } + } + + void Free() + { + if (this->IsValid()) + { + ecc_free(this->KeyStorage.get()); + } + + ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get())); + } + private: - ecc_key KeyStorage; + std::shared_ptr KeyStorage; }; static Key GenerateKey(int bits); @@ -54,20 +86,42 @@ namespace Utils class Key { public: - Key() { ZeroMemory(&this->KeyStorage, sizeof(this->KeyStorage)); }; - Key(rsa_key* key) : Key(*key) {}; - Key(rsa_key key) : KeyStorage(key) {}; - Key(const Key& obj) : KeyStorage(obj.KeyStorage) {}; - - ~Key() {} + Key() : KeyStorage(new rsa_key) + { + ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get())); + }; + Key(rsa_key* key) : Key() { std::memmove(this->KeyStorage.get(), key, sizeof(*key)); }; + Key(rsa_key key) : Key(&key) {}; + ~Key() + { + if (this->KeyStorage.use_count() <= 1) + { + this->Free(); + } + }; rsa_key* GetKeyPtr() { - return &this->KeyStorage; + return this->KeyStorage.get(); + } + + bool IsValid() + { + return (!Utils::MemIsSet(this->KeyStorage.get(), 0, sizeof(*this->KeyStorage.get()))); + } + + void Free() + { + if (this->IsValid()) + { + rsa_free(this->KeyStorage.get()); + } + + ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get())); } private: - rsa_key KeyStorage; + std::shared_ptr KeyStorage; }; static Key GenerateKey(int bits); diff --git a/src/Utils/Utils.cpp b/src/Utils/Utils.cpp index 6e84f061..01c23330 100644 --- a/src/Utils/Utils.cpp +++ b/src/Utils/Utils.cpp @@ -30,6 +30,22 @@ namespace Utils return (strstr(haystack.data(), needle.data()) == (haystack.data() + haystack.size() - needle.size())); } + // Complementary function for memset, which checks if a memory is set + bool MemIsSet(void* mem, char chr, size_t length) + { + char* memArr = reinterpret_cast(mem); + + for (size_t i = 0; i < length; i++) + { + if (memArr[i] != chr) + { + return false; + } + } + + return true; + } + std::vector Explode(const std::string& str, char delim) { std::vector result; diff --git a/src/Utils/Utils.hpp b/src/Utils/Utils.hpp index 9a6a6515..95725575 100644 --- a/src/Utils/Utils.hpp +++ b/src/Utils/Utils.hpp @@ -20,6 +20,8 @@ namespace Utils void WriteFile(std::string file, std::string data); std::string ReadFile(std::string file); + bool MemIsSet(void* mem, char chr, size_t length); + class InfoString { public: