Skip to content

Commit

Permalink
Add mod hashing, improve download protocol (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
lionkor authored Oct 4, 2024
2 parents 7944e9d + dc13e4a commit 768f11f
Showing 1 changed file with 255 additions and 3 deletions.
258 changes: 255 additions & 3 deletions src/Network/Resources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#include <iomanip>
#include <ios>
#include <mutex>
#include <nlohmann/json.hpp>
#include <openssl/err.h>
#include <openssl/evp.h>

#if defined(_WIN32)
#include <ws2tcpip.h>
Expand All @@ -25,6 +28,7 @@

#include "Logger.h"
#include "Startup.h"
#include <Utils.h>
#include <atomic>
#include <cmath>
#include <cstring>
Expand All @@ -33,7 +37,8 @@
#include <future>
#include <iostream>
#include <thread>
#include <Utils.h>

#include "hashpp.h"

namespace fs = std::filesystem;

Expand Down Expand Up @@ -214,6 +219,35 @@ SOCKET InitDSock() {
return DSock;
}

std::vector<char> SingleNormalDownload(SOCKET MSock, uint64_t Size, const std::string& Name) {
DownloadSpeed = 0;

uint64_t GRcv = 0;

std::thread Au([&] { AsyncUpdate(GRcv, Size, Name); });

const std::vector<char> MData = TCPRcvRaw(MSock, GRcv, Size);

if (MData.empty()) {
KillSocket(MSock);
Terminate = true;
Au.join();
return {};
}

// ensure that GRcv is good before joining the async update thread
GRcv = MData.size();
if (GRcv != Size) {
error("Something went wrong during download; didn't get enough data. Expected " + std::to_string(Size) + " bytes, got " + std::to_string(GRcv) + " bytes instead");
Terminate = true;
Au.join();
return {};
}

Au.join();
return MData;
}

std::vector<char> MultiDownload(SOCKET MSock, SOCKET DSock, uint64_t Size, const std::string& Name) {
DownloadSpeed = 0;

Expand Down Expand Up @@ -253,7 +287,7 @@ std::vector<char> MultiDownload(SOCKET MSock, SOCKET DSock, uint64_t Size, const

Au.join();

std::vector<char> Result{};
std::vector<char> Result {};
Result.insert(Result.begin(), MData.begin(), MData.end());
Result.insert(Result.end(), DData.begin(), DData.end());
return Result;
Expand All @@ -265,13 +299,230 @@ void InvalidResource(const std::string& File) {
Terminate = true;
}

std::string GetSha256HashReallyFast(const std::string& filename) {
try {
EVP_MD_CTX* mdctx;
const EVP_MD* md;
uint8_t sha256_value[EVP_MAX_MD_SIZE];
md = EVP_sha256();
if (md == nullptr) {
throw std::runtime_error("EVP_sha256() failed");
}

mdctx = EVP_MD_CTX_new();
if (mdctx == nullptr) {
throw std::runtime_error("EVP_MD_CTX_new() failed");
}
if (!EVP_DigestInit_ex2(mdctx, md, NULL)) {
EVP_MD_CTX_free(mdctx);
throw std::runtime_error("EVP_DigestInit_ex2() failed");
}

std::ifstream stream(filename, std::ios::binary);

const size_t FileSize = std::filesystem::file_size(filename);
size_t Read = 0;
std::vector<char> Data;
while (Read < FileSize) {
Data.resize(size_t(std::min<size_t>(FileSize - Read, 4096)));
size_t RealDataSize = Data.size();
stream.read(Data.data(), std::streamsize(Data.size()));
if (stream.eof() || stream.fail()) {
RealDataSize = size_t(stream.gcount());
}
Data.resize(RealDataSize);
if (RealDataSize == 0) {
break;
}
if (RealDataSize > 0 && !EVP_DigestUpdate(mdctx, Data.data(), Data.size())) {
EVP_MD_CTX_free(mdctx);
throw std::runtime_error("EVP_DigestUpdate() failed");
}
Read += RealDataSize;
}
unsigned int sha256_len = 0;
if (!EVP_DigestFinal_ex(mdctx, sha256_value, &sha256_len)) {
EVP_MD_CTX_free(mdctx);
throw std::runtime_error("EVP_DigestFinal_ex() failed");
}
EVP_MD_CTX_free(mdctx);

std::string result;
for (size_t i = 0; i < sha256_len; i++) {
char buf[3];
sprintf(buf, "%02x", sha256_value[i]);
buf[2] = 0;
result += buf;
}
return result;
} catch (const std::exception& e) {
error("Sha256 hashing of '" + filename + "' failed: " + e.what());
return "";
}
}

struct ModInfo {
static std::vector<ModInfo> ParseModInfosFromPacket(const std::string& packet) {
std::vector<ModInfo> modInfos;
try {
auto json = nlohmann::json::parse(packet);
for (const auto& entry : json) {
ModInfo modInfo {
.FileName = entry["file_name"],
.FileSize = entry["file_size"],
.Hash = entry["hash"],
.HashAlgorithm = entry["hash_algorithm"],
};
modInfos.push_back(modInfo);
}
} catch (const std::exception& e) {
debug(std::string("Failed to receive mod list: ") + e.what());
error("Failed to receive mod list!");
// TODO: Cry and die
}
return modInfos;
}
std::string FileName;
size_t FileSize;
std::string Hash;
std::string HashAlgorithm;
};

void NewSyncResources(SOCKET Sock, const std::string& Mods, const std::vector<ModInfo> ModInfos) {
if (!SecurityWarning())
return;

info("Checking Resources...");

CheckForDir();

std::string t;
for (const auto& mod : ModInfos) {
t += mod.FileName + ";";
}

if (t.empty())
CoreSend("L");
else
CoreSend("L" + t);
t.clear();

info("Syncing...");

int ModNo = 1;
int TotalMods = ModInfos.size();
for (auto ModInfoIter = ModInfos.begin(), AlsoModInfoIter = ModInfos.begin(); ModInfoIter != ModInfos.end() && !Terminate; ++ModInfoIter, ++AlsoModInfoIter) {
if (ModInfoIter->Hash.length() < 8 || ModInfoIter->HashAlgorithm != "sha256") {
error("Unsupported hash algorithm or invalid hash for '" + ModInfoIter->FileName + "'");
Terminate = true;
return;
}
auto FileName = std::filesystem::path(ModInfoIter->FileName).stem().string() + "-" + ModInfoIter->Hash.substr(0, 8) + std::filesystem::path(ModInfoIter->FileName).extension().string();
auto PathToSaveTo = (fs::path(CachingDirectory) / FileName).string();
if (fs::exists(PathToSaveTo) && GetSha256HashReallyFast(PathToSaveTo) == ModInfoIter->Hash) {
debug("Mod '" + FileName + "' found in cache");
std::this_thread::sleep_for(std::chrono::milliseconds(50));
try {
if (!fs::exists(GetGamePath() + "mods/multiplayer")) {
fs::create_directories(GetGamePath() + "mods/multiplayer");
}
auto modname = ModInfoIter->FileName;
#if defined(__linux__)
// Linux version of the game doesnt support uppercase letters in mod names
for (char& c : modname) {
c = ::tolower(c);
}
#endif
debug("Mod name: " + modname);
auto name = std::filesystem::path(GetGamePath()) / "mods/multiplayer" / modname;
std::string tmp_name = name.string();
tmp_name += ".tmp";

fs::copy_file(PathToSaveTo, tmp_name, fs::copy_options::overwrite_existing);
fs::rename(tmp_name, name);
} catch (std::exception& e) {
error("Failed copy to the mods folder! " + std::string(e.what()));
Terminate = true;
continue;
}
WaitForConfirm();
continue;
}
CheckForDir();
std::string FName = ModInfoIter->FileName;
do {
debug("Loading file '" + FName + "' to '" + PathToSaveTo + "'");
TCPSend("f" + ModInfoIter->FileName, Sock);

std::string Data = TCPRcv(Sock);
if (Data == "CO" || Terminate) {
Terminate = true;
UUl("Server cannot find " + FName);
break;
}

std::string Name = std::to_string(ModNo) + "/" + std::to_string(TotalMods) + ": " + FName;

std::vector<char> DownloadedFile = SingleNormalDownload(Sock, ModInfoIter->FileSize, Name);

if (Terminate)
break;
UpdateUl(false, std::to_string(ModNo) + "/" + std::to_string(TotalMods) + ": " + FName);

// 1. write downloaded file to disk
{
std::ofstream OutFile(PathToSaveTo, std::ios::binary | std::ios::trunc);
OutFile.write(DownloadedFile.data(), DownloadedFile.size());
OutFile.flush();
}
// 2. verify size
if (std::filesystem::file_size(PathToSaveTo) != DownloadedFile.size()) {
error("Failed to write the entire file '" + PathToSaveTo + "' correctly (file size mismatch)");
Terminate = true;
}
} while (fs::file_size(PathToSaveTo) != ModInfoIter->FileSize && !Terminate);
if (!Terminate) {
if (!fs::exists(GetGamePath() + "mods/multiplayer")) {
fs::create_directories(GetGamePath() + "mods/multiplayer");
}

// Linux version of the game doesnt support uppercase letters in mod names
#if defined(__linux__)
for (char& c : FName) {
c = ::tolower(c);
}
#endif

fs::copy_file(PathToSaveTo, std::filesystem::path(GetGamePath()) / "mods/multiplayer" / FName, fs::copy_options::overwrite_existing);
}
WaitForConfirm();
++ModNo;
}

if (!Terminate) {
TCPSend("Done", Sock);
info("Done!");
} else {
UlStatus = "Ulstart";
info("Connection Terminated!");
}
}

void SyncResources(SOCKET Sock) {
std::string Ret = Auth(Sock);

auto ModInfos = ModInfo::ParseModInfosFromPacket(Ret);

if (!ModInfos.empty()) {
NewSyncResources(Sock, Ret, ModInfos);
return;
}

if (Ret.empty())
return;

if (!SecurityWarning())
return;
return;

info("Checking Resources...");
CheckForDir();
Expand Down Expand Up @@ -396,6 +647,7 @@ void SyncResources(SOCKET Sock) {
}
WaitForConfirm();
}

KillSocket(DSock);
if (!Terminate) {
TCPSend("Done", Sock);
Expand Down

0 comments on commit 768f11f

Please sign in to comment.