diff --git a/src/client/component/updater.cpp b/src/client/component/updater.cpp new file mode 100644 index 0000000..920aea1 --- /dev/null +++ b/src/client/component/updater.cpp @@ -0,0 +1,217 @@ +#include +#include "loader/component_loader.hpp" +#include "splash.hpp" + +#include + +#include +#include +#include +#include + +#define VERSION_URL "https://nightly.link/momo5502/boiii/workflows/build/" GIT_BRANCH "/Version.zip" +#define BINARY_URL "https://nightly.link/momo5502/boiii/workflows/build/" GIT_BRANCH "/Release%20Binary.zip" + +namespace updater +{ + namespace + { + std::string get_version_zip() + { + const auto version_zip = utils::http::get_data(VERSION_URL); + if (!version_zip || version_zip->empty()) + { + throw std::runtime_error("Invalid version data"); + } + + return *version_zip; + } + + std::string get_version() + { + const auto zip = get_version_zip(); + auto res = utils::compression::zip::extract(zip); + return res["version.txt"]; + } + + bool requires_update() + { + return get_version() != GIT_HASH; + } + + std::string get_self_file() + { + const auto self = utils::nt::library::get_by_address(get_self_file); + return self.get_path(); + } + + std::string get_leftover_file() + { + return get_self_file() + ".old"; + } + + std::string download_update(utils::progress_ui& progress_ui) + { + const auto data = utils::http::get_data( + BINARY_URL, {}, [&progress_ui](const size_t total, const size_t current) + { + if (progress_ui.is_cancelled()) + { + throw std::runtime_error("Cancelled"); + } + + if (total > 0) + { + progress_ui.set_progress(current, total); + } + }); + + if (!data) + { + throw std::runtime_error("Invalid binary"); + } + + return *data; + } + + void activate_update() + { + utils::nt::relaunch_self(); + TerminateProcess(GetCurrentProcess(), 0); + } + + std::string get_binary(const std::string& data) + { + auto res = utils::compression::zip::extract(data); + if (res.size() == 1) + { + for (auto& file : res) + { + return std::move(file.second); + } + } + + throw std::runtime_error("Invalid data"); + } + + void cleanup_update() + { + const auto leftover_file = get_leftover_file(); + for (size_t i = 0; i < 3; ++i) + { + if (utils::io::remove_file(leftover_file)) + { + break; + } + + std::this_thread::sleep_for(1s); + } + } + + void perform_update() + { + utils::progress_ui progress_ui{}; + progress_ui.set_title("Updating BOIII"); + progress_ui.set_line(1, "Downloading update..."); + progress_ui.show(true); + + const auto update_data = download_update(progress_ui); + + if (progress_ui.is_cancelled()) + { + return; + } + + // Is it good to add artificial sleeps? + // Makes the ui nice, for sure. + std::this_thread::sleep_for(2s); + + progress_ui.set_line(1, "Installing update..."); + progress_ui.set_progress(1, 1); + + const auto self_file = get_self_file(); + const auto leftover_file = get_leftover_file(); + + const auto binary = get_binary(update_data); + + cleanup_update(); + utils::io::move_file(self_file, leftover_file); + utils::io::write_file(self_file, binary); + + std::this_thread::sleep_for(2s); + } + } + + class component final : public component_interface + { + public: + component() + { + cleanup_update(); + + this->update_thread_ = std::thread([this] + { + this->update(); + }); + } + + ~component() override + { + if (this->update_thread_.joinable()) + { + this->update_thread_.detach(); + } + } + + void pre_start() override + { + join(); + } + + void pre_destroy() override + { + join(); + } + + void post_unpack() override + { + join(); + } + + int priority() override + { + return 999; + } + + private: + std::thread update_thread_{}; + + void join() + { + if (this->update_thread_.joinable()) + { + this->update_thread_.join(); + } + } + + void update() + { + try + { + if (requires_update()) + { + splash::hide(); + perform_update(); + activate_update(); + } + } + catch (...) + { + } + } + }; +} + +#if !defined(DEBUG) && defined(CI) +REGISTER_COMPONENT(updater::component) +#endif diff --git a/src/common/utils/com.cpp b/src/common/utils/com.cpp index 83cfc47..efb1fe0 100644 --- a/src/common/utils/com.cpp +++ b/src/common/utils/com.cpp @@ -12,45 +12,49 @@ namespace utils::com { namespace { - [[maybe_unused]] class _ + void initialize_com() { - public: - _() + static struct x { - if(FAILED(CoInitialize(nullptr))) + x() { - throw std::runtime_error("Failed to initialize the component object model"); + if (FAILED(CoInitialize(nullptr))) + { + throw std::runtime_error("Failed to initialize the component object model"); + } } - } - ~_() - { - CoUninitialize(); - } - } __; + ~x() + { + CoUninitialize(); + } + } xx; + } } bool select_folder(std::string& out_folder, const std::string& title, const std::string& selected_folder) { + initialize_com(); + CComPtr file_dialog{}; - if(FAILED(CoCreateInstance(CLSID_FileOpenDialog, nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&file_dialog)))) + if (FAILED(CoCreateInstance(CLSID_FileOpenDialog, nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&file_dialog)))) { throw std::runtime_error("Failed to create co instance"); } DWORD dw_options; - if(FAILED(file_dialog->GetOptions(&dw_options))) + if (FAILED(file_dialog->GetOptions(&dw_options))) { throw std::runtime_error("Failed to get options"); } - if(FAILED(file_dialog->SetOptions(dw_options | FOS_PICKFOLDERS))) + if (FAILED(file_dialog->SetOptions(dw_options | FOS_PICKFOLDERS))) { throw std::runtime_error("Failed to set options"); } - std::wstring wide_title(title.begin(), title.end()); - if(FAILED(file_dialog->SetTitle(wide_title.data()))) + const std::wstring wide_title(title.begin(), title.end()); + if (FAILED(file_dialog->SetTitle(wide_title.data()))) { throw std::runtime_error("Failed to set title"); } @@ -69,7 +73,7 @@ namespace utils::com } IShellItem* shell_item = nullptr; - if(FAILED(SHCreateItemFromParsingName(wide_selected_folder.data(), NULL, IID_PPV_ARGS(&shell_item)))) + if (FAILED(SHCreateItemFromParsingName(wide_selected_folder.data(), NULL, IID_PPV_ARGS(&shell_item)))) { throw std::runtime_error("Failed to create item from parsing name"); } @@ -81,7 +85,7 @@ namespace utils::com } const auto result = file_dialog->Show(nullptr); - if(result == HRESULT_FROM_WIN32(ERROR_CANCELLED)) + if (result == HRESULT_FROM_WIN32(ERROR_CANCELLED)) { return false; } @@ -92,13 +96,13 @@ namespace utils::com } CComPtr result_item{}; - if(FAILED(file_dialog->GetResult(&result_item))) + if (FAILED(file_dialog->GetResult(&result_item))) { throw std::runtime_error("Failed to get result"); } PWSTR raw_path = nullptr; - if(FAILED(result_item->GetDisplayName(SIGDN_FILESYSPATH, &raw_path))) + if (FAILED(result_item->GetDisplayName(SIGDN_FILESYSPATH, &raw_path))) { throw std::runtime_error("Failed to get path display name"); } @@ -116,8 +120,11 @@ namespace utils::com CComPtr create_progress_dialog() { + initialize_com(); + CComPtr progress_dialog{}; - if(FAILED(CoCreateInstance(CLSID_ProgressDialog, nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&progress_dialog)))) + if (FAILED( + CoCreateInstance(CLSID_ProgressDialog, nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&progress_dialog)))) { throw std::runtime_error("Failed to create co instance"); } diff --git a/src/common/utils/compression.cpp b/src/common/utils/compression.cpp index 7354c6a..d2a0d17 100644 --- a/src/common/utils/compression.cpp +++ b/src/common/utils/compression.cpp @@ -3,6 +3,7 @@ #include #include +#include #include "io.hpp" #include "finally.hpp" @@ -164,5 +165,235 @@ namespace utils::compression return true; } + + namespace + { + std::optional> read_zip_file_entry(unzFile& zip_file) + { + char filename[1024]{}; + unz_file_info file_info{}; + if (unzGetCurrentFileInfo(zip_file, &file_info, filename, sizeof(filename), nullptr, 0, nullptr, 0) != + UNZ_OK) + { + return {}; + } + + if (unzOpenCurrentFile(zip_file) != UNZ_OK) + { + return {}; + } + + auto _ = finally([&zip_file] + { + unzCloseCurrentFile(zip_file); + }); + + int error = UNZ_OK; + std::string out_buffer{}; + static thread_local char buffer[0x2000]; + + do + { + error = unzReadCurrentFile(zip_file, buffer, sizeof(buffer)); + if (error < 0) + { + return {}; + } + + // Write data to file. + if (error > 0) + { + out_buffer.append(buffer, error); + } + } + while (error > 0); + + return std::pair{filename, out_buffer}; + } + + class memory_file + { + public: + memory_file(const std::string& data) + : data_(data) + { + func_def_.opaque = this; + func_def_.zopen64_file = open_file_static; + func_def_.zseek64_file = seek_file_static; + func_def_.ztell64_file = tell_file_static; + func_def_.zread_file = read_file_static; + func_def_.zwrite_file = write_file_static; + func_def_.zclose_file = close_file_static; + func_def_.zerror_file = testerror_file_static; + } + + const char* get_name() const + { + return "blub"; + } + + zlib_filefunc64_def* get_func_def() + { + return &this->func_def_; + } + + private: + const std::string& data_; + size_t offset_{0}; + zlib_filefunc64_def func_def_{}; + + voidpf open_file(const void* filename, const int mode) const + { + if (mode != (ZLIB_FILEFUNC_MODE_READ | ZLIB_FILEFUNC_MODE_EXISTING)) + { + return nullptr; + } + + if (strcmp(static_cast(filename), get_name()) != 0) + { + return nullptr; + } + + return reinterpret_cast(1); + } + + long seek_file(const voidpf stream, const ZPOS64_T offset, const int origin) + { + if (stream != reinterpret_cast(1)) + { + return -1; + } + + size_t target_base = this->data_.size(); + if (origin == ZLIB_FILEFUNC_SEEK_CUR) + { + target_base = this->offset_; + } + else if (origin == ZLIB_FILEFUNC_SEEK_SET) + { + target_base = 0; + } + + const auto target_offset = target_base + offset; + if (target_offset > this->data_.size()) + { + return -1; + } + + this->offset_ = target_offset; + return 0; + } + + ZPOS64_T tell_file(const voidpf stream) const + { + if (stream != reinterpret_cast(1)) + { + return static_cast(-1); + } + + return this->offset_; + } + + uLong read_file(const voidpf stream, void* buf, const uLong size) + { + if (stream != reinterpret_cast(1)) + { + return 0; + } + + const auto file_end = this->data_.size(); + const auto start = this->offset_; + const auto end = std::min(this->offset_ + size, file_end); + const auto length = end - start; + + memcpy(buf, this->data_.data() + start, length); + this->offset_ = end; + + return static_cast(length); + } + + static voidpf open_file_static(const voidpf opaque, const void* filename, const int mode) + { + return static_cast(opaque)->open_file(filename, mode); + } + + static long seek_file_static(const voidpf opaque, const voidpf stream, const ZPOS64_T offset, + const int origin) + { + return static_cast(opaque)->seek_file(stream, offset, origin); + } + + static ZPOS64_T tell_file_static(const voidpf opaque, const voidpf stream) + { + return static_cast(opaque)->tell_file(stream); + } + + static uLong read_file_static(const voidpf opaque, const voidpf stream, void* buf, const uLong size) + { + return static_cast(opaque)->read_file(stream, buf, size); + } + + static uLong write_file_static(voidpf, voidpf, const void*, uLong) + { + return 0; + } + + static int close_file_static(voidpf, voidpf) + { + return 0; + } + + static int testerror_file_static(voidpf, voidpf) + { + return 0; + } + }; + } + + std::unordered_map extract(const std::string& data) + { + memory_file mem_file(data); + + auto zip_file = unzOpen2_64(mem_file.get_name(), mem_file.get_func_def()); + auto _ = finally([&zip_file] + { + if (zip_file) + { + unzClose(zip_file); + } + }); + + if (!zip_file) + { + return {}; + } + + unz_global_info global_info{}; + if (unzGetGlobalInfo(zip_file, &global_info) != UNZ_OK) + { + return {}; + } + + std::unordered_map files{}; + files.reserve(global_info.number_entry); + + for (auto i = 0ul; i < global_info.number_entry; ++i) + { + if (i > 0 && unzGoToNextFile(zip_file) != UNZ_OK) + { + break; + } + + auto file = read_zip_file_entry(zip_file); + if (!file) + { + continue; + } + + files[std::move(file->first)] = std::move(file->second); + } + + return files; + } } } diff --git a/src/common/utils/compression.hpp b/src/common/utils/compression.hpp index dfe36ad..ad79734 100644 --- a/src/common/utils/compression.hpp +++ b/src/common/utils/compression.hpp @@ -24,5 +24,7 @@ namespace utils::compression private: std::unordered_map files_; }; + + std::unordered_map extract(const std::string& data); } }; diff --git a/src/common/utils/http.cpp b/src/common/utils/http.cpp index 40795e8..70ae26f 100644 --- a/src/common/utils/http.cpp +++ b/src/common/utils/http.cpp @@ -10,11 +10,11 @@ namespace utils::http { struct progress_helper { - const std::function* callback{}; + const std::function* callback{}; std::exception_ptr exception{}; }; - int progress_callback(void *clientp, const curl_off_t /*dltotal*/, const curl_off_t dlnow, const curl_off_t /*ultotal*/, const curl_off_t /*ulnow*/) + int progress_callback(void *clientp, const curl_off_t dltotal, const curl_off_t dlnow, const curl_off_t /*ultotal*/, const curl_off_t /*ulnow*/) { auto* helper = static_cast(clientp); @@ -22,7 +22,7 @@ namespace utils::http { if (*helper->callback) { - (*helper->callback)(dlnow); + (*helper->callback)(dltotal, dlnow); } } catch(...) @@ -44,7 +44,7 @@ namespace utils::http } } - std::optional get_data(const std::string& url, const headers& headers, const std::function& callback) + std::optional get_data(const std::string& url, const headers& headers, const std::function& callback) { curl_slist* header_list = nullptr; auto* curl = curl_easy_init(); diff --git a/src/common/utils/http.hpp b/src/common/utils/http.hpp index b5248bc..be9cf08 100644 --- a/src/common/utils/http.hpp +++ b/src/common/utils/http.hpp @@ -8,6 +8,6 @@ namespace utils::http { using headers = std::unordered_map; - std::optional get_data(const std::string& url, const headers& headers = {}, const std::function& callback = {}); + std::optional get_data(const std::string& url, const headers& headers = {}, const std::function& callback = {}); std::future> get_data_async(const std::string& url, const headers& headers = {}); } diff --git a/src/common/utils/io.cpp b/src/common/utils/io.cpp index 4968f44..a03d1bb 100644 --- a/src/common/utils/io.cpp +++ b/src/common/utils/io.cpp @@ -6,7 +6,12 @@ namespace utils::io { bool remove_file(const std::string& file) { - return DeleteFileA(file.data()) == TRUE; + if(DeleteFileA(file.data()) != FALSE) + { + return true; + } + + return GetLastError() == ERROR_FILE_NOT_FOUND; } bool move_file(const std::string& src, const std::string& target) diff --git a/src/common/utils/progress_ui.cpp b/src/common/utils/progress_ui.cpp new file mode 100644 index 0000000..4f6b6ea --- /dev/null +++ b/src/common/utils/progress_ui.cpp @@ -0,0 +1,45 @@ +#include "progress_ui.hpp" + +#include + +namespace utils +{ + progress_ui::progress_ui() + { + this->dialog_ = utils::com::create_progress_dialog(); + if (!this->dialog_) + { + throw std::runtime_error{"Failed to create dialog"}; + } + } + + progress_ui::~progress_ui() + { + this->dialog_->StopProgressDialog(); + } + + void progress_ui::show(const bool marquee) const + { + this->dialog_->StartProgressDialog(nullptr, nullptr, PROGDLG_AUTOTIME | (marquee ? PROGDLG_MARQUEEPROGRESS : 0), nullptr); + } + + void progress_ui::set_progress(const size_t current, const size_t max) const + { + this->dialog_->SetProgress64(current, max); + } + + void progress_ui::set_line(const int line, const std::string& text) const + { + this->dialog_->SetLine(line, utils::string::convert(text).data(), false, nullptr); + } + + void progress_ui::set_title(const std::string& title) const + { + this->dialog_->SetTitle(utils::string::convert(title).data()); + } + + bool progress_ui::is_cancelled() const + { + return this->dialog_->HasUserCancelled(); + } +} diff --git a/src/common/utils/progress_ui.hpp b/src/common/utils/progress_ui.hpp new file mode 100644 index 0000000..75b3de5 --- /dev/null +++ b/src/common/utils/progress_ui.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "com.hpp" + +namespace utils +{ + class progress_ui + { + public: + progress_ui(); + ~progress_ui(); + + void show(bool marquee) const; + + void set_progress(size_t current, size_t max) const; + void set_line(int line, const std::string& text) const; + void set_title(const std::string& title) const; + + bool is_cancelled() const; + + private: + CComPtr dialog_{}; + }; +}