iw4x-client/src/Components/Modules/Download.cpp

428 lines
12 KiB
C++
Raw Normal View History

2016-01-08 21:21:59 -05:00
#include "STDInclude.hpp"
namespace Components
{
// Registry of all active transfers: ClientDownloads holds downloads started via Download::Get,
// ServerDownloads holds transfers created by the dlRequest handler (DownloadRequest).
Download::Container Download::DataContainer;
// Looks up an in-flight client-side download by id.
// Returns a pointer into DataContainer.ClientDownloads, or nullptr when no entry has that id.
// NOTE(review): depending on the container type, the pointer may be invalidated by later insertions/removals.
Download::Container::DownloadCL* Download::FindClientDownload(int id)
{
	auto& downloads = Download::DataContainer.ClientDownloads;

	for (auto it = downloads.begin(); it != downloads.end(); ++it)
	{
		if (it->id != id) continue;
		return &(*it);
	}

	return nullptr;
}
// Looks up an in-flight server-side transfer by id.
// Returns a pointer into DataContainer.ServerDownloads, or nullptr when no entry has that id.
// NOTE(review): depending on the container type, the pointer may be invalidated by later insertions/removals.
Download::Container::DownloadSV* Download::FindServerDownload(int id)
{
	auto& downloads = Download::DataContainer.ServerDownloads;

	for (auto it = downloads.begin(); it != downloads.end(); ++it)
	{
		if (it->id != id) continue;
		return &(*it);
	}

	return nullptr;
}
// Removes the first client download whose id matches; does nothing if none matches.
void Download::RemoveClientDownload(int id)
{
	auto& downloads = Download::DataContainer.ClientDownloads;

	auto match = downloads.begin();
	while (match != downloads.end() && match->id != id)
	{
		++match;
	}

	if (match != downloads.end())
	{
		downloads.erase(match);
	}
}
// Removes the first server transfer whose id matches; does nothing if none matches.
void Download::RemoveServerDownload(int id)
{
	auto& downloads = Download::DataContainer.ServerDownloads;

	auto match = downloads.begin();
	while (match != downloads.end() && match->id != id)
	{
		++match;
	}

	if (match != downloads.end())
	{
		downloads.erase(match);
	}
}
// Reports whether part 'packet' is currently recorded in the transfer's sentParts list.
bool Download::HasSentPacket(Download::Container::DownloadSV* download, int packet)
{
	bool found = false;

	for (auto it = download->sentParts.begin(); !found && it != download->sentParts.end(); ++it)
	{
		found = (*it == packet);
	}

	return found;
}
// Reports whether part 'packet' has already been stored in the download's 'parts' map.
// Uses the map's own keyed lookup instead of a linear walk: HasReceivedAllPackets, AssembleBuffer
// and Frame call this once per part, so a linear scan made each of those passes quadratic.
bool Download::HasReceivedPacket(Download::Container::DownloadCL* download, int packet)
{
	return download->parts.find(packet) != download->parts.end();
}
// True once every part id in [0, maxParts) has been received (vacuously true when maxParts <= 0).
bool Download::HasReceivedAllPackets(Download::Container::DownloadCL* download)
{
	int part = 0;

	while (part < download->maxParts && Download::HasReceivedPacket(download, part))
	{
		++part;
	}

	return part >= download->maxParts;
}
// Pops the leading raw int id off the front of a network payload.
// @param data  Payload as received; on return it holds only the bytes that followed the id.
// @return      The id, or -1 if the payload is too short to contain one (data is then cleared).
int Download::ReadPacketId(std::string &data)
{
	// Payloads come from the network: a short one must not be read past its end, and
	// data.size() - sizeof(int) must not be allowed to wrap around.
	if (data.size() < sizeof(int))
	{
		data.clear();
		return -1;
	}

	// Copy the bytes out instead of dereferencing a cast pointer, so the buffer's alignment doesn't matter.
	int id = 0;
	data.copy(reinterpret_cast<char*>(&id), sizeof(id), 0);
	data.erase(0, sizeof(id));
	return id;
}
// Client handlers
// Client handler for "dlAckRequest": the server announces the part count and a challenge for one
// of our downloads. We record the part count, mark the download acknowledged and echo
// (id, challenge) back as "dlAckResponse".
void Download::AckRequest(Network::Address target, std::string data)
{
	if (data.size() < sizeof(Download::Container::AckRequest)) return; // Drop invalid packets, if they were important, we'll re-request them later

	// Copy the header out rather than aliasing the string's storage.
	Download::Container::AckRequest request;
	data.copy(reinterpret_cast<char*>(&request), sizeof(request), 0);

	// Check the declared challenge length against what actually follows the header. Unlike
	// "header + length", this cannot overflow, and a negative (signed) length becomes huge and is rejected.
	const size_t payloadSize = data.size() - sizeof(request);
	if (static_cast<size_t>(request.length) > payloadSize) return; // Again, drop invalid packets

	auto download = Download::FindClientDownload(request.id);
	if (download && download->target == target && !download->acknowledged)
	{
		std::string challenge(data.data() + sizeof(request), static_cast<size_t>(request.length));
		download->acknowledged = true;
		download->lastPing = Game::Com_Milliseconds();
		download->maxParts = request.maxPackets;

		std::string packet;
		packet.append(reinterpret_cast<char*>(&download->id), sizeof(int));
		packet.append(challenge);
		Network::SendCommand(target, "dlAckResponse", packet);
	}
}
// Client handler for "dlPacketResponse": validates one part of a download (length + hash),
// stores it, and when every part is present hands the assembled buffer to the success callback
// and removes the download.
void Download::PacketResponse(Network::Address target, std::string data)
{
	if (data.size() < sizeof(Download::Container::Packet)) return; // Drop invalid packets, if they were important, we'll re-request them later

	// Copy the header out rather than aliasing the string's storage.
	Download::Container::Packet packet;
	data.copy(reinterpret_cast<char*>(&packet), sizeof(packet), 0);

	// Overflow-safe length check: compare against the bytes that actually follow the header.
	const size_t payloadSize = data.size() - sizeof(packet);
	if (static_cast<size_t>(packet.length) > payloadSize) return; // Again, drop invalid packets

	auto download = Download::FindClientDownload(packet.id);
	if (download && download->target == target)
	{
		download->lastPing = Game::Com_Milliseconds();

		// Only part ids announced in the ack are ever assembled; storing anything else would just
		// grow 'parts' without bound.
		if (packet.partId < 0 || packet.partId >= download->maxParts) return;

		std::string packetData(data.data() + sizeof(packet), static_cast<size_t>(packet.length));
		if (packet.hash == Utils::OneAtATime(packetData.data(), packetData.size()))
		{
			download->parts[packet.partId] = std::move(packetData);

			if (Download::HasReceivedAllPackets(download))
			{
				// Copy everything needed before running user code. The callback may start another
				// download, which can reallocate ClientDownloads and invalidate 'download' together
				// with the std::function stored inside it.
				int id = download->id;
				auto successCallback = download->successCallback;
				std::string buffer = Download::AssembleBuffer(download);

				successCallback(id, std::move(buffer));
				Download::RemoveClientDownload(id);
			}
		}
		else
		{
			Logger::Print("Hash invalid!\n");
		}
	}
}
// Server handlers
// Server handler for "dlAckResponse": the client echoes our challenge for transfer 'id'.
// A matching challenge marks the transfer acknowledged so Frame() starts streaming parts;
// a mismatch aborts the transfer.
void Download::AckResponse(Network::Address target, std::string data)
{
	int id = Download::ReadPacketId(data);
	std::string challenge = Utils::ParseChallenge(data); // TODO: Maybe optimize this to ensure length matches

	auto download = Download::FindServerDownload(id);
	if (!(download && download->target == target)) return;

	if (download->challenge == challenge)
	{
		download->lastPing = Game::Com_Milliseconds();
		download->acknowledged = true;
		Logger::Print("Client acknowledged!\n");
		return;
	}

	Logger::Print("Invalid download challenge!\n");
	Download::RemoveServerDownload(id);
}
// Server handler for "dlMissRequest": after the transfer id, the payload is a packed array of raw
// int part ids the client still lacks. Each is marked dirty so Frame() sends it again.
void Download::MissingRequest(Network::Address target, std::string data)
{
	int id = Download::ReadPacketId(data);
	auto download = Download::FindServerDownload(id);
	if (download && download->target == target)
	{
		// Walk the id array by offset. The old condition "(size % 4) >= 4" could never be true, so
		// no part was ever re-sent. Trailing bytes that do not form a whole int are ignored.
		for (size_t offset = 0; offset + sizeof(int) <= data.size(); offset += sizeof(int))
		{
			int packet = 0;
			data.copy(reinterpret_cast<char*>(&packet), sizeof(packet), offset);
			Download::MarkPacketAsDirty(download, packet);
		}
	}
}
// Server handler for "dlRequest": registers a new outgoing transfer under the client's id and
// replies with "dlAckRequest" (part count + random challenge).
// NOTE(review): the served content is a fixed 10 MB placeholder; the requested file name left in
// 'data' after the id is currently ignored.
void Download::DownloadRequest(Network::Address target, std::string data)
{
	int id = Download::ReadPacketId(data);

	Download::Container::DownloadSV download;
	download.id = id;
	download.target = target;
	download.acknowledged = false;
	download.startTime = Game::Com_Milliseconds();
	download.lastPing = Game::Com_Milliseconds();
	download.maxParts = 0;

	// 1,000,000 x "1234567890": reserve once instead of growing the buffer repeatedly.
	download.buffer.reserve(1000000 * 10);
	for (int i = 0; i < 1000000; ++i)
	{
		download.buffer.append("1234567890");
	}

	// Number of PACKET_SIZE chunks, rounded up for a trailing partial chunk.
	download.maxParts = download.buffer.size() / PACKET_SIZE;
	if (download.buffer.size() % PACKET_SIZE) download.maxParts++;

	download.challenge = Utils::VA("%X", Utils::Cryptography::Rand::GenerateInt());

	Download::Container::AckRequest request;
	request.id = id;
	request.maxPackets = download.maxParts;
	request.length = download.challenge.size();

	std::string packet;
	packet.append(reinterpret_cast<char*>(&request), sizeof(request));
	packet.append(download.challenge);

	// Move, don't copy: the transfer owns the whole multi-megabyte buffer.
	Download::DataContainer.ServerDownloads.push_back(std::move(download));
	Network::SendCommand(target, "dlAckRequest", packet);
}
// Concatenates parts [0, maxParts) in order into one buffer.
// Returns an empty string if any part is still missing.
std::string Download::AssembleBuffer(Download::Container::DownloadCL* download)
{
	std::string buffer;

	for (int i = 0; i < download->maxParts; ++i)
	{
		// A single keyed lookup replaces the linear presence check plus operator[] access.
		auto part = download->parts.find(i);
		if (part == download->parts.end()) return "";
		buffer.append(part->second);
	}

	return buffer;
}
// Sends "dlMissRequest" listing the part ids we still lack: the download id followed by each
// missing id, all as raw ints. Refreshes lastPing; does nothing when the list is empty.
void Download::RequestMissingPackets(Download::Container::DownloadCL* download, std::vector<int> packets)
{
	if (packets.empty()) return;

	download->lastPing = Game::Com_Milliseconds();

	std::string request;
	request.append(reinterpret_cast<char*>(&download->id), sizeof(int));

	for (size_t index = 0; index < packets.size(); ++index)
	{
		request.append(reinterpret_cast<char*>(&packets[index]), sizeof(int));
	}

	Network::SendCommand(download->target, "dlMissRequest", request);
}
// Removes every occurrence of 'packet' from the transfer's sentParts, so Frame() treats it as
// unsent and transmits it again.
void Download::MarkPacketAsDirty(Download::Container::DownloadSV* download, int packet)
{
	auto& sent = download->sentParts;

	// erase() returns the next valid iterator. The old "erase, reset to begin(), then ++" skipped
	// the first element and incremented past end() when the last remaining entry was erased (UB).
	for (auto i = sent.begin(); i != sent.end();)
	{
		if (*i == packet)
		{
			i = sent.erase(i);
		}
		else
		{
			++i;
		}
	}
}
// Sends part 'packet' of a server transfer as "dlPacketResponse" (Packet header + payload) and
// records it in sentParts. Ignores null transfers and part ids outside [0, maxParts).
void Download::SendPacket(Download::Container::DownloadSV* download, int packet)
{
	// A negative id would index before the start of the buffer.
	if (!download || packet < 0 || packet >= download->maxParts) return;

	download->lastPing = Game::Com_Milliseconds();
	download->sentParts.push_back(packet);

	Download::Container::Packet packetContainer;
	packetContainer.id = download->id;
	packetContainer.partId = packet;

	// Every part is PACKET_SIZE bytes except the last, which carries the remainder.
	int size = ((packet + 1) == download->maxParts ? (download->buffer.size() % PACKET_SIZE) : PACKET_SIZE);
	size = (size == 0 ? PACKET_SIZE : size); // If remaining data equals packet PACKET_SIZE, size would be 0, so adjust it.

	std::string data(download->buffer.data() + (packet * PACKET_SIZE), size);
	packetContainer.length = data.size();
	packetContainer.hash = Utils::OneAtATime(data.data(), data.size());

	// The command name is passed to SendCommand separately, as every other sender in this file does.
	// Repeating it in the payload would (presumably) shift the header PacketResponse reads at offset 0.
	std::string response;
	response.append(reinterpret_cast<char*>(&packetContainer), sizeof(packetContainer));
	response.append(data);
	Network::SendCommand(download->target, "dlPacketResponse", response);
}
void Download::Frame()
{
if (!Download::DataContainer.ClientDownloads.empty())
2016-01-08 21:21:59 -05:00
{
2016-01-24 13:58:13 -05:00
for (auto i = Download::DataContainer.ClientDownloads.begin(); i != Download::DataContainer.ClientDownloads.end(); ++i)
2016-01-08 21:21:59 -05:00
{
if ((Game::Com_Milliseconds() - i->lastPing) > (DOWNLOAD_TIMEOUT * 2))
{
i->failureCallback(i->id);
Download::DataContainer.ClientDownloads.erase(i);
return;
}
// Request missing parts
2016-01-09 09:30:13 -05:00
if (i->acknowledged && (Game::Com_Milliseconds() - i->lastPing) > DOWNLOAD_TIMEOUT)
2016-01-08 21:21:59 -05:00
{
std::vector<int> missingPackets;
2016-01-24 13:58:13 -05:00
for (int j = 0; j < i->maxParts; ++j)
2016-01-08 21:21:59 -05:00
{
2016-02-15 14:32:41 -05:00
if (!Download::HasReceivedPacket(&(*i), j))
2016-01-08 21:21:59 -05:00
{
missingPackets.push_back(j);
}
}
2016-01-09 09:30:13 -05:00
2016-02-15 14:32:41 -05:00
Download::RequestMissingPackets(&(*i), missingPackets);
2016-01-08 21:21:59 -05:00
}
}
}
if (!Download::DataContainer.ServerDownloads.empty())
2016-01-08 21:21:59 -05:00
{
2016-01-24 13:58:13 -05:00
for (auto i = Download::DataContainer.ServerDownloads.begin(); i != Download::DataContainer.ServerDownloads.end(); ++i)
2016-01-08 21:21:59 -05:00
{
if ((Game::Com_Milliseconds() - i->lastPing) > (DOWNLOAD_TIMEOUT * 3))
{
Download::DataContainer.ServerDownloads.erase(i);
return;
}
int packets = 0;
2016-01-24 13:58:13 -05:00
for (int j = 0; j < i->maxParts && packets <= FRAME_PACKET_LIMIT && i->acknowledged; ++j)
2016-01-08 21:21:59 -05:00
{
2016-02-15 14:32:41 -05:00
if (!Download::HasSentPacket(&(*i), j))
2016-01-08 21:21:59 -05:00
{
2016-01-09 09:30:13 -05:00
//Logger::Print("Sending packet...\n");
2016-02-15 14:32:41 -05:00
Download::SendPacket(&(*i), j);
2016-01-08 21:21:59 -05:00
packets++;
}
}
}
}
}
// Starts downloading 'file' from 'target' and returns the download id.
// successCallback(id, data) runs once every part has arrived; failureCallback(id) runs if the
// transfer goes silent for too long (see Frame).
int Download::Get(Network::Address target, std::string file, std::function<void(int, std::string)> successCallback, std::function<void(int)> failureCallback)
{
	// Ids are based on the current time; skip any still in use so two downloads started within
	// the same millisecond do not share an id (and thus each other's packets).
	int id = Game::Com_Milliseconds();
	while (Download::FindClientDownload(id)) ++id;

	Download::Container::DownloadCL download;
	download.id = id;
	download.target = target;
	download.acknowledged = false;
	download.startTime = Game::Com_Milliseconds();
	download.lastPing = Game::Com_Milliseconds();
	download.maxParts = 0;
	download.failureCallback = std::move(failureCallback);
	download.successCallback = std::move(successCallback);
	Download::DataContainer.ClientDownloads.push_back(std::move(download));

	// The command name goes to SendCommand only. Embedding "dlRequest\n" in the payload as well
	// made DownloadRequest read "dlRe" as the id.
	std::string request;
	request.append(reinterpret_cast<char*>(&id), sizeof(int));
	request.append(file);
	Network::SendCommand(target, "dlRequest", request);

	return id;
}
// Component constructor: when ENABLE_EXPERIMENTAL_UDP_DOWNLOAD is defined, hooks Frame into the
// frame loop, registers the client/server network handlers and adds the "zob" debug command.
Download::Download()
{
#ifdef ENABLE_EXPERIMENTAL_UDP_DOWNLOAD
	// Frame handlers
	QuickPatch::OnFrame(Download::Frame);

	// Register client handlers
	Network::Handle("dlAckRequest", Download::AckRequest);
	Network::Handle("dlPacketResponse", Download::PacketResponse);

	// Register server handlers (each command exactly once; "dlAckResponse" used to be registered twice)
	Network::Handle("dlAckResponse", Download::AckResponse);
	Network::Handle("dlMissRequest", Download::MissingRequest);
	Network::Handle("dlRequest", Download::DownloadRequest);

	// Debug command: requests "test" from a hard-coded LAN address and prints the elapsed time.
	Command::Add("zob", [] (Command::Params params)
	{
		Logger::Print("Requesting!\n");
		Download::Get(Network::Address("192.168.0.23:28960"), "test", [] (int id, std::string data)
		{
			// Guard the lookup rather than relying on the download still being registered when the callback runs.
			auto download = Download::FindClientDownload(id);
			if (download)
			{
				Logger::Print("Download succeeded %d!\n", Game::Com_Milliseconds() - download->startTime);
			}
		}, [] (int id)
		{
			Logger::Print("Download failed!\n");
		});
	});
#endif
}
// Component destructor: discards every transfer still registered, including the callbacks stored
// with client downloads.
Download::~Download()
{
	auto& container = Download::DataContainer;
	container.ServerDownloads.clear();
	container.ClientDownloads.clear();
}
}