Skip to content

Commit

Permalink
Merge branch 'DOR-893_models_directory_write_permissions' into 'master'
Browse files Browse the repository at this point in the history
DOR-893: Fix eager model download write permission check

Closes DOR-893

See merge request machine-learning/dorado!1207
  • Loading branch information
HalfPhoton committed Sep 24, 2024
2 parents f15c0b3 + cb07be0 commit f156ae6
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 18 deletions.
2 changes: 1 addition & 1 deletion dorado/cli/model_resolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ inline std::optional<std::filesystem::path> get_models_directory(
std::exit(EXIT_FAILURE);
}
path = std::filesystem::canonical(path);
spdlog::debug("set models directory to: '{}'", path.u8string());
spdlog::debug("Set models directory to: '{}'.", path.u8string());
return path;
} else if (env_path != nullptr) {
auto path = std::filesystem::path(env_path);
Expand Down
10 changes: 9 additions & 1 deletion dorado/model_downloader/downloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,15 @@ bool Downloader::validate_checksum(std::string_view data, const models::ModelInf
}

void Downloader::extract(const fs::path& archive) const {
elz::extractZip(archive, m_directory);
spdlog::trace("Extracting model archive: '{}'.", archive.u8string());

try {
elz::extractZip(archive, m_directory);
} catch (const elz::zip_exception& e) {
spdlog::error("Failed to unzip model archive: '{}'.", e.what());
throw;
}

fs::remove(archive);
}

Expand Down
23 changes: 17 additions & 6 deletions dorado/model_downloader/model_downloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,43 @@ bool download_models(const std::string& target_directory, const std::string& sel
}

std::filesystem::path ModelDownloader::get(const ModelInfo& model, const std::string& description) {
// parent_dir is either a temporary directory in the CWD or the users selection
const auto parent_dir = utils::get_downloads_path(m_models_dir);
const fs::path parent_dir =
m_models_dir.has_value() ? m_models_dir.value() : utils::create_temporary_directory();

// clang-tidy warns about performance-no-automatic-move if |temp_model_dir| is const. It should be treated as such though.
/*const*/ fs::path model_dir = parent_dir / model.name;

if (std::filesystem::exists(model_dir)) {
if (fs::exists(model_dir)) {
spdlog::trace("Found existing model at '{}'.", model_dir.u8string());
return model_dir;
}

if (m_models_dir.has_value()) {
spdlog::trace("Model does not exist at '{}' - downloading it instead.",
model_dir.u8string());
}

if (!utils::has_write_permission(parent_dir)) {
throw std::runtime_error("Failed to prepare model download directory");
}

if (!download_models(parent_dir.u8string(), model.name)) {
throw std::runtime_error("Failed to download + " + description + " model: " + model.name);
throw std::runtime_error("Failed to download + " + description + " model: '" + model.name +
"'.");
}

if (is_temporary()) {
// Check parent_dir is temp preventing unintentionally deleting work
if (parent_dir.filename().u8string().find(utils::TEMP_MODELS_DIR_PREFIX) == 0) {
spdlog::trace("Temporary {} model '{}' downloaded into: '{}'", description, model.name,
parent_dir.u8string());
m_temp_models.emplace(parent_dir);
} else {
spdlog::warn(
"Temporary {} model directory does not have the expected name '{}' at: '{}'",
description, utils::TEMP_MODELS_DIR_PREFIX, parent_dir.u8string());
}
}

spdlog::trace("Downloaded {} model into: '{}'.", description, model_dir.u8string());
return model_dir;
}

Expand Down
16 changes: 12 additions & 4 deletions dorado/models/model_complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ ModelComplex ModelComplexParser::parse(const std::string& arg) {
spdlog::trace("Model option: '{}' unknown - assuming path", variant_str);
selection.model = ModelVariantPair{model_variant};
} else {
spdlog::trace("'{}' found variant: '{}' and version: '{}'", variant_str,
to_string(model_variant), to_string(version));
spdlog::trace("Model complex: '{}' found variant: '{}' and version: '{}'",
variant_str, to_string(model_variant), to_string(version));
selection.model = ModelVariantPair{model_variant, version};
}
} else {
Expand Down Expand Up @@ -135,9 +135,9 @@ ModelComplexSearch::ModelComplexSearch(const ModelComplex& complex,
: m_complex(complex),
m_chemistry(chemistry),
m_suggestions(suggestions),
m_simplex_model_info(simplex()) {}
m_simplex_model_info(resolve_simplex()) {}

ModelInfo ModelComplexSearch::simplex() const {
ModelInfo ModelComplexSearch::resolve_simplex() const {
if (m_complex.is_path()) {
throw std::logic_error(
"Cannot use model ModelComplexSearch with a simplex model complex which is a path");
Expand All @@ -146,6 +146,14 @@ ModelInfo ModelComplexSearch::simplex() const {
m_suggestions);
}

ModelInfo ModelComplexSearch::simplex() const {
if (m_complex.is_path()) {
throw std::logic_error(
"Cannot use model ModelComplexSearch with a simplex model complex which is a path");
}
return m_simplex_model_info;
}

ModelInfo ModelComplexSearch::stereo() const {
return find_model(stereo_models(), "stereo duplex", m_chemistry, m_simplex_model_info.simplex,
ModsVariantPair(), false);
Expand Down
8 changes: 5 additions & 3 deletions dorado/models/model_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ class ModelComplexParser {
// could be a actual model complex tries to find the ModelInfo from the models lib.
class ModelComplexSearch {
public:
ModelComplexSearch(const ModelComplex& selection, Chemistry chemsitry, bool suggestions);
ModelComplexSearch(const ModelComplex& selection, Chemistry chemistry, bool suggestions);
// Return the model complex
ModelComplex complex() { return m_complex; }
// Return the chemistry found
// Return the chemistry
Chemistry chemistry() { return m_chemistry; }
// Find a simplex model which matches the user's command and chemistry
// Return the simplex model found during initialisation
ModelInfo simplex() const;
// Find a stereo model which matches the chemistry
ModelInfo stereo() const;
Expand All @@ -59,6 +59,8 @@ class ModelComplexSearch {
std::vector<ModelInfo> simplex_mods() const;

private:
// Resolve the simplex model which matches the user's command and chemistry
ModelInfo resolve_simplex() const;
// The user's model complex input
const ModelComplex m_complex;
// If a ModelVariant was set, the chemistry (e.g. R10.4.1 / RNA004) is deduced from the
Expand Down
11 changes: 8 additions & 3 deletions dorado/models/models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ ModelInfo find_model(const std::vector<ModelInfo>& models,
if (Chemistry::UNKNOWN == chemistry) {
throw std::runtime_error("Cannot get model without chemistry");
}
const auto matches = find_models(models, chemistry, model, mods);
const std::vector<ModelInfo> matches = find_models(models, chemistry, model, mods);

if (matches.empty()) {
spdlog::error("Failed to get {} model", description);
Expand All @@ -146,8 +146,13 @@ ModelInfo find_model(const std::vector<ModelInfo>& models,
throw std::runtime_error("No matches for " + format_msg(chemistry, model, mods));
}

// Get the only match or the latest model
return matches.back();
// Get the only match or the latest model as models are sorted in ascending version order
const ModelInfo& selection = matches.back();
if (matches.size() > 1) {
spdlog::trace("Selected {} model: '{}' from {} matches.", description, selection.name,
matches.size());
}
return selection;
}

std::vector<ModelInfo> find_models(const std::vector<ModelInfo>& models,
Expand Down

0 comments on commit f156ae6

Please sign in to comment.