Fix cryptography memory leaks :P

This commit is contained in:
momo5502 2016-02-08 18:43:31 +01:00
parent 4d36a0b9ed
commit 87c239a1dc
6 changed files with 100 additions and 19 deletions

View File

@ -2,6 +2,8 @@
namespace Components namespace Components
{ {
Utils::Cryptography::ECDSA::Key Node::SignatureKey;
std::vector<Node::NodeEntry> Node::Nodes; std::vector<Node::NodeEntry> Node::Nodes;
std::vector<Node::DediEntry> Node::Dedis; std::vector<Node::DediEntry> Node::Dedis;
@ -396,6 +398,9 @@ namespace Components
// ZoneBuilder doesn't require node stuff // ZoneBuilder doesn't require node stuff
if (ZoneBuilder::IsEnabled()) return; if (ZoneBuilder::IsEnabled()) return;
// Generate our ECDSA key
Node::SignatureKey = Utils::Cryptography::ECDSA::GenerateKey(512);
Dvar::OnInit([] () Dvar::OnInit([] ()
{ {
Node::Dedis.clear(); Node::Dedis.clear();
@ -570,6 +575,8 @@ namespace Components
Node::~Node() Node::~Node()
{ {
Node::SignatureKey.Free();
Node::StoreNodes(true); Node::StoreNodes(true);
Node::Nodes.clear(); Node::Nodes.clear();
Node::Dedis.clear(); Node::Dedis.clear();

View File

@ -44,6 +44,7 @@ namespace Components
struct NodeEntry struct NodeEntry
{ {
Network::Address address; Network::Address address;
Utils::Cryptography::ECDSA::Key publicKey;
EntryState state; EntryState state;
int lastTime; // Last time we heard anything from the server itself int lastTime; // Last time we heard anything from the server itself
int lastHeartbeat; // Last time we got a heartbeat from it int lastHeartbeat; // Last time we got a heartbeat from it
@ -84,6 +85,8 @@ namespace Components
}; };
#pragma pack(pop) #pragma pack(pop)
static Utils::Cryptography::ECDSA::Key SignatureKey;
static std::vector<NodeEntry> Nodes; static std::vector<NodeEntry> Nodes;
static std::vector<DediEntry> Dedis; static std::vector<DediEntry> Dedis;

View File

@ -30,7 +30,6 @@ namespace Utils
ECDSA::Key key; ECDSA::Key key;
register_prng(&sprng_desc); register_prng(&sprng_desc);
register_hash(&sha1_desc);
ltc_mp = ltm_desc; ltc_mp = ltm_desc;
@ -41,7 +40,7 @@ namespace Utils
std::string ECDSA::SignMessage(Key key, std::string message) std::string ECDSA::SignMessage(Key key, std::string message)
{ {
uint8_t buffer[0x200]; // Default size is 512 uint8_t buffer[512];
DWORD length = sizeof(buffer); DWORD length = sizeof(buffer);
register_prng(&sprng_desc); register_prng(&sprng_desc);
@ -81,7 +80,7 @@ namespace Utils
std::string RSA::SignMessage(RSA::Key key, std::string message) std::string RSA::SignMessage(RSA::Key key, std::string message)
{ {
uint8_t buffer[0x200]; // Default size is 512 uint8_t buffer[512];
DWORD length = sizeof(buffer); DWORD length = sizeof(buffer);
register_prng(&sprng_desc); register_prng(&sprng_desc);

View File

@ -14,21 +14,33 @@ namespace Utils
class Key class Key
{ {
public: public:
Key() { ZeroMemory(&this->KeyStorage, sizeof(this->KeyStorage)); }; Key() : KeyStorage(new ecc_key)
Key(ecc_key* key) : Key(*key) {}; {
Key(ecc_key key) : KeyStorage(key) {}; ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get()));
Key(const Key& obj) : KeyStorage(obj.KeyStorage) {}; };
Key(ecc_key* key) : Key() { std::memmove(this->KeyStorage.get(), key, sizeof(*key)); };
Key(ecc_key key) : Key(&key) {};
~Key()
{
if (this->KeyStorage.use_count() <= 1)
{
this->Free();
}
};
~Key() {} bool IsValid()
{
return (!Utils::MemIsSet(this->KeyStorage.get(), 0, sizeof(*this->KeyStorage.get())));
}
ecc_key* GetKeyPtr() ecc_key* GetKeyPtr()
{ {
return &this->KeyStorage; return this->KeyStorage.get();
} }
std::string GetPublicKey() std::string GetPublicKey()
{ {
uint8_t buffer[0x1000] = { 0 }; uint8_t buffer[512] = { 0 };
DWORD length = sizeof(buffer); DWORD length = sizeof(buffer);
if (ecc_ansi_x963_export(this->GetKeyPtr(), buffer, &length) == CRYPT_OK) if (ecc_ansi_x963_export(this->GetKeyPtr(), buffer, &length) == CRYPT_OK)
@ -39,8 +51,28 @@ namespace Utils
return ""; return "";
} }
void Set(std::string pubKeyBuffer)
{
this->Free();
if (ecc_ansi_x963_import(reinterpret_cast<const uint8_t*>(pubKeyBuffer.data()), pubKeyBuffer.size(), this->KeyStorage.get()) != CRYPT_OK)
{
ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get()));
}
}
void Free()
{
if (this->IsValid())
{
ecc_free(this->KeyStorage.get());
}
ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get()));
}
private: private:
ecc_key KeyStorage; std::shared_ptr<ecc_key> KeyStorage;
}; };
static Key GenerateKey(int bits); static Key GenerateKey(int bits);
@ -54,20 +86,42 @@ namespace Utils
class Key class Key
{ {
public: public:
Key() { ZeroMemory(&this->KeyStorage, sizeof(this->KeyStorage)); }; Key() : KeyStorage(new rsa_key)
Key(rsa_key* key) : Key(*key) {}; {
Key(rsa_key key) : KeyStorage(key) {}; ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get()));
Key(const Key& obj) : KeyStorage(obj.KeyStorage) {}; };
Key(rsa_key* key) : Key() { std::memmove(this->KeyStorage.get(), key, sizeof(*key)); };
~Key() {} Key(rsa_key key) : Key(&key) {};
~Key()
{
if (this->KeyStorage.use_count() <= 1)
{
this->Free();
}
};
rsa_key* GetKeyPtr() rsa_key* GetKeyPtr()
{ {
return &this->KeyStorage; return this->KeyStorage.get();
}
bool IsValid()
{
return (!Utils::MemIsSet(this->KeyStorage.get(), 0, sizeof(*this->KeyStorage.get())));
}
void Free()
{
if (this->IsValid())
{
rsa_free(this->KeyStorage.get());
}
ZeroMemory(this->KeyStorage.get(), sizeof(*this->KeyStorage.get()));
} }
private: private:
rsa_key KeyStorage; std::shared_ptr<rsa_key> KeyStorage;
}; };
static Key GenerateKey(int bits); static Key GenerateKey(int bits);

View File

@ -30,6 +30,22 @@ namespace Utils
return (strstr(haystack.data(), needle.data()) == (haystack.data() + haystack.size() - needle.size())); return (strstr(haystack.data(), needle.data()) == (haystack.data() + haystack.size() - needle.size()));
} }
// Complementary function for memset, which checks if a memory is set
bool MemIsSet(void* mem, char chr, size_t length)
{
char* memArr = reinterpret_cast<char*>(mem);
for (size_t i = 0; i < length; i++)
{
if (memArr[i] != chr)
{
return false;
}
}
return true;
}
std::vector<std::string> Explode(const std::string& str, char delim) std::vector<std::string> Explode(const std::string& str, char delim)
{ {
std::vector<std::string> result; std::vector<std::string> result;

View File

@ -20,6 +20,8 @@ namespace Utils
void WriteFile(std::string file, std::string data); void WriteFile(std::string file, std::string data);
std::string ReadFile(std::string file); std::string ReadFile(std::string file);
bool MemIsSet(void* mem, char chr, size_t length);
class InfoString class InfoString
{ {
public: public: