Start implementing download protocol.

This commit is contained in:
momo5502 2016-01-09 03:21:59 +01:00
parent 8f9b703dbb
commit 5624390e67
6 changed files with 519 additions and 2 deletions

View File

@ -21,6 +21,7 @@ namespace Components
Loader::Register(new Console());
Loader::Register(new IPCPipe());
Loader::Register(new Network());
Loader::Register(new Download());
Loader::Register(new Playlist());
Loader::Register(new RawFiles());
Loader::Register(new Renderer());

View File

@ -33,6 +33,7 @@ namespace Components
#include "Modules\IPCPipe.hpp"
#include "Modules\Network.hpp"
#include "Modules\Party.hpp" // Destroys the order, but requires network classes :D
#include "Modules\Download.hpp"
#include "Modules\Playlist.hpp"
#include "Modules\RawFiles.hpp"
#include "Modules\Renderer.hpp"

View File

@ -0,0 +1,422 @@
#include "STDInclude.hpp"
namespace Components
{
Download::Container Download::DataContainer;
Download::Container::DownloadCL* Download::FindClientDownload(int id)
{
for (auto &download : Download::DataContainer.ClientDownloads)
{
if (download.id == id)
{
return &download;
}
}
return nullptr;
}
Download::Container::DownloadSV* Download::FindServerDownload(int id)
{
for (auto &download : Download::DataContainer.ServerDownloads)
{
if (download.id == id)
{
return &download;
}
}
return nullptr;
}
void Download::RemoveClientDownload(int id)
{
for (auto i = Download::DataContainer.ClientDownloads.begin(); i != Download::DataContainer.ClientDownloads.end(); i++)
{
if (i->id == id)
{
Download::DataContainer.ClientDownloads.erase(i);
return;
}
}
}
void Download::RemoveServerDownload(int id)
{
for (auto i = Download::DataContainer.ServerDownloads.begin(); i != Download::DataContainer.ServerDownloads.end(); i++)
{
if (i->id == id)
{
Download::DataContainer.ServerDownloads.erase(i);
return;
}
}
}
bool Download::HasSentPacket(Download::Container::DownloadSV* download, int packet)
{
for (auto sentPacket : download->sentParts)
{
if (packet == sentPacket)
{
return true;
}
}
return false;
}
bool Download::HasReceivedPacket(Download::Container::DownloadCL* download, int packet)
{
if (download->parts.size())
{
for (auto i = download->parts.begin(); i != download->parts.end(); i++)
{
if (i->first == packet)
{
return true;
}
}
}
return false;
}
bool Download::HasReceivedAllPackets(Download::Container::DownloadCL* download)
{
for (int i = 0; i < download->maxParts; i++)
{
if (!Download::HasReceivedPacket(download, i))
{
return false;
}
}
return true;
}
int Download::ReadPacketId(std::string &data)
{
int id = *(int*)data.data();
data = std::string(data.data() + sizeof(int), data.size() - sizeof(int));
return id;
}
// Client handlers
void Download::AckRequest(Network::Address target, std::string data)
{
if (data.size() < sizeof(Download::Container::AckRequest)) return; // Drop invalid packets, if they were important, we'll re-request them later
Download::Container::AckRequest* request = (Download::Container::AckRequest*)data.data();
if (data.size() < (sizeof(Download::Container::AckRequest) + request->length)) return; // Again, drop invalid packets
auto download = Download::FindClientDownload(request->id);
if (download && download->target == target && !download->acknowledged)
{
std::string challenge(data.data() + sizeof(Download::Container::AckRequest), request->length);
download->acknowledged = true;
download->lastPing = Game::Com_Milliseconds();
download->maxParts = request->maxPackets;
std::string packet = "dlAckResponse\n";
packet.append(reinterpret_cast<char*>(&download->id), sizeof(int));
packet.append(challenge);
Network::SendRaw(target, packet);
}
}
void Download::PacketResponse(Network::Address target, std::string data)
{
if (data.size() < sizeof(Download::Container::Packet)) return; // Drop invalid packets, if they were important, we'll re-request them later
Download::Container::Packet* packet = (Download::Container::Packet*)data.data();
if (data.size() < (sizeof(Download::Container::Packet) + packet->length)) return; // Again, drop invalid packets
auto download = Download::FindClientDownload(packet->id);
if (download && download->target == target)
{
download->lastPing = Game::Com_Milliseconds();
std::string packetData(data.data() + sizeof(Download::Container::Packet), packet->length);
if (packet->hash == Utils::OneAtATime(packetData.data(), packetData.size()))
{
download->parts[packet->partId] = packetData;
if (Download::HasReceivedAllPackets(download))
{
download->successCallback(download->id, Download::AssembleBuffer(download));
Download::RemoveClientDownload(download->id);
}
}
}
}
// Server handlers
void Download::AckResponse(Network::Address target, std::string data)
{
int id = Download::ReadPacketId(data);
std::string challenge = Utils::ParseChallenge(data); // TODO: Maybe optimize this to ensure length matches
auto download = Download::FindServerDownload(id);
if (download && download->target == target)
{
if (download->challenge != challenge)
{
Logger::Print("Invalid download challenge!\n");
Download::RemoveServerDownload(id);
}
else
{
download->lastPing = Game::Com_Milliseconds();
download->acknowledged = true;
}
}
}
void Download::MissingRequest(Network::Address target, std::string data)
{
int id = Download::ReadPacketId(data);
auto download = Download::FindServerDownload(id);
if (download && download->target == target)
{
while ((data.size() % 4) >= 4)
{
Download::MarkPacketAsDirty(download, *(int*)data.data());
data = data.substr(4);
}
}
}
void Download::DownloadRequest(Network::Address target, std::string data)
{
int id = Download::ReadPacketId(data);
Download::Container::DownloadSV download;
download.id = id;
download.target = target;
download.acknowledged = false;
download.startTime = Game::Com_Milliseconds();
download.lastPing = Game::Com_Milliseconds();
download.maxParts = 0;
// Generate random 40kb buffer
for (int i = 0; i < 10000; i++)
{
download.buffer.append(Utils::VA("%i", i));
}
download.maxParts = download.buffer.size() / PACKET_SIZE;
if (download.buffer.size() % PACKET_SIZE) download.maxParts++;
download.challenge = Utils::VA("%X", Game::Com_Milliseconds());
Download::Container::AckRequest request;
request.id = id;
request.maxPackets = download.maxParts;
request.length = download.challenge.size();
std::string packet = "dlAckRequest\n";
packet.append(reinterpret_cast<char*>(&request), sizeof(request));
packet.append(download.challenge);
Download::DataContainer.ServerDownloads.push_back(download);
Network::SendRaw(target, packet);
}
std::string Download::AssembleBuffer(Download::Container::DownloadCL* download)
{
std::string buffer;
for (int i = 0; i < download->maxParts; i++)
{
if (!Download::HasReceivedPacket(download, i)) return "";
buffer.append(download->parts[i]);
}
return buffer;
}
void Download::RequestMissingPackets(Download::Container::DownloadCL* download, std::vector<int> packets)
{
if (packets.size())
{
std::string data = "dlMissRequest\n";
data.append(reinterpret_cast<char*>(&download->id), sizeof(int));
for (auto &packet : packets)
{
data.append((char*)&packet, sizeof(int));
}
Network::SendRaw(download->target, data);
}
}
void Download::MarkPacketAsDirty(Download::Container::DownloadSV* download, int packet)
{
if (download->sentParts.size())
{
for (auto i = download->sentParts.begin(); i != download->sentParts.end(); i++)
{
if (*i == packet)
{
download->sentParts.erase(i);
i = download->sentParts.begin();
}
}
}
}
void Download::SendPacket(Download::Container::DownloadSV* download, int packet)
{
if (!download || packet < download->maxParts) return;
download->lastPing = Game::Com_Milliseconds();
download->sentParts.push_back(packet);
Download::Container::Packet packetContainer;
packetContainer.id = download->id;
packetContainer.partId = packet;
int size = ((packet + 1) == download->maxParts ? (download->buffer.size() % PACKET_SIZE) : PACKET_SIZE);
size = (size == 0 ? PACKET_SIZE : size); // If remaining data equals packet PACKET_SIZE, size would be 0, so adjust it.
std::string data(download->buffer.data() + (packet * PACKET_SIZE), size);
packetContainer.length = data.size();
packetContainer.hash = Utils::OneAtATime(data.data(), data.size());
std::string response = "dlPacketResponse\n";
response.append((char*)&packetContainer, sizeof(packetContainer));
response.append(data);
Network::SendRaw(download->target, response);
}
void Download::Frame()
{
if (Download::DataContainer.ClientDownloads.size())
{
for (auto i = Download::DataContainer.ClientDownloads.begin(); i != Download::DataContainer.ClientDownloads.end(); i++)
{
if ((Game::Com_Milliseconds() - i->lastPing) > (DOWNLOAD_TIMEOUT * 2))
{
i->failureCallback(i->id);
Download::DataContainer.ClientDownloads.erase(i);
return;
}
// Request missing parts
if ((Game::Com_Milliseconds() - i->lastPing) > DOWNLOAD_TIMEOUT)
{
std::vector<int> missingPackets;
for (int j = 0; j < i->maxParts; j++)
{
if (!Download::HasReceivedPacket(&*i, j))
{
missingPackets.push_back(j);
}
}
}
}
}
if (Download::DataContainer.ServerDownloads.size())
{
for (auto i = Download::DataContainer.ServerDownloads.begin(); i != Download::DataContainer.ServerDownloads.end(); i++)
{
if ((Game::Com_Milliseconds() - i->lastPing) > (DOWNLOAD_TIMEOUT * 3))
{
Download::DataContainer.ServerDownloads.erase(i);
return;
}
int packets = 0;
for (int j = 0; j < i->maxParts && packets <= FRAME_PACKET_LIMIT; j++)
{
if (!Download::HasSentPacket(&*i, j))
{
Download::SendPacket(&*i, j);
packets++;
}
}
}
}
}
int Download::Get(Network::Address target, std::string file, std::function<void(int, std::string)> successCallback, std::function<void(int)> failureCallback)
{
Download::Container::DownloadCL download;
download.id = Game::Com_Milliseconds();
download.target = target;
download.acknowledged = false;
download.startTime = Game::Com_Milliseconds();
download.lastPing = Game::Com_Milliseconds();
download.maxParts = 0;
download.failureCallback = failureCallback;
download.successCallback = successCallback;
Download::DataContainer.ClientDownloads.push_back(download);
std::string response = "dlRequest\n";
response.append((char*)&download.id, sizeof(int));
response.append(file);
Network::SendRaw(target, response);
return download.id;
}
Download::Download()
{
return;
// Frame handlers
if (Dedicated::IsDedicated())
{
Dedicated::OnFrame(Download::Frame);
}
else
{
Renderer::OnFrame(Download::Frame);
}
// Register client handlers
Network::Handle("dlAckRequest", Download::AckRequest);
Network::Handle("dlPacketResponse", Download::PacketResponse);
// Register server handlers
Network::Handle("dlAckResponse", Download::AckResponse);
Network::Handle("dlMissRequest", Download::MissingRequest);
Network::Handle("dlAckResponse", Download::AckResponse);
Network::Handle("dlRequest", Download::DownloadRequest);
Command::Add("zob", [] (Command::Params params)
{
Logger::Print("Requesting!\n");
Download::Get(Network::Address("192.168.0.23:28960"), "test", [] (int id, std::string data)
{
Logger::Print("Download succeeded!\n");
}, [] (int id)
{
Logger::Print("Download failed!\n");
});
});
}
Download::~Download()
{
Download::DataContainer.ServerDownloads.clear();
Download::DataContainer.ClientDownloads.clear();
}
}

View File

@ -0,0 +1,93 @@
#define FRAME_PACKET_LIMIT 10
#define DOWNLOAD_TIMEOUT 2500
#define PACKET_SIZE 1000
namespace Components
{
class Download : public Component
{
public:
Download();
~Download();
const char* GetName() { return "Download"; };
static int Get(Network::Address target, std::string file, std::function<void(int, std::string)> successCallback, std::function<void(int)> failureCallback);
private:
struct Container
{
struct DownloadCL
{
int id;
int startTime;
int lastPing;
bool acknowledged;
int maxParts;
Network::Address target;
std::map<int, std::string> parts;
std::function<void(int)> failureCallback;
std::function<void(int, std::string)> successCallback;
};
struct DownloadSV
{
int id;
int startTime;
int lastPing;
bool acknowledged;
std::vector<int> sentParts;
int maxParts;
Network::Address target;
std::string challenge;
std::string buffer;
};
struct Packet
{
int id;
int partId;
int length;
unsigned int hash;
};
struct AckRequest
{
int id;
int maxPackets;
int length;
};
std::vector<DownloadCL> ClientDownloads;
std::vector<DownloadSV> ServerDownloads;
};
static void Frame();
static Container DataContainer;
// Client handlers
static void AckRequest(Network::Address target, std::string data);
static void PacketResponse(Network::Address target, std::string data);
// Server handlers
static void AckResponse(Network::Address target, std::string data);
static void MissingRequest(Network::Address target, std::string data);
static void DownloadRequest(Network::Address target, std::string data);
// Helper functions
static Container::DownloadCL* FindClientDownload(int id);
static Container::DownloadSV* FindServerDownload(int id);
static void RemoveClientDownload(int id);
static void RemoveServerDownload(int id);
static bool HasSentPacket(Container::DownloadSV* download, int packet);
static bool HasReceivedPacket(Container::DownloadCL* download, int packet);
static bool HasReceivedAllPackets(Container::DownloadCL* download);
static std::string AssembleBuffer(Container::DownloadCL* download);
static void RequestMissingPackets(Container::DownloadCL* download, std::vector<int> packets);
static void MarkPacketAsDirty(Container::DownloadSV* download, int packet);
static void SendPacket(Container::DownloadSV* download, int packet);
static int ReadPacketId(std::string &data);
};
}

View File

@ -62,7 +62,7 @@ namespace Utils
}
}
unsigned int OneAtATime(char *key, size_t len)
unsigned int OneAtATime(const char *key, size_t len)
{
unsigned int hash, i;
for (hash = i = 0; i < len; ++i)

View File

@ -5,7 +5,7 @@ namespace Utils
bool EndsWith(const char* heystack, const char* needle);
std::vector<std::string> Explode(const std::string& str, char delim);
void Replace(std::string &string, std::string find, std::string replace);
unsigned int OneAtATime(char *key, size_t len);
unsigned int OneAtATime(const char *key, size_t len);
std::string &LTrim(std::string &s);
std::string &RTrim(std::string &s);
std::string &Trim(std::string &s);