Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#2342: Change schema to allow for id or seq_id #2343

Merged
merged 10 commits into from
Sep 9, 2024
11 changes: 6 additions & 5 deletions scripts/JSON_data_files_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,14 @@ def validate_comm_links(all_jsons):
task_ids = set()

for data in all_jsons:
tasks = data["phases"][n]["tasks"]
id_key = "id" if "id" in tasks[0]["entity"] else "seq_id"
task_ids.update({int(task["entity"][id_key]) for task in tasks})

if data["phases"][n].get("communications") is not None:
comms = data["phases"][n]["communications"]
comm_ids.update({int(comm["from"]["id"]) for comm in comms})
comm_ids.update({int(comm["to"]["id"]) for comm in comms})

tasks = data["phases"][n]["tasks"]
task_ids.update({int(task["entity"]["id"]) for task in tasks})
comm_ids.update({int(comm["from"][id_key]) for comm in comms})
comm_ids.update({int(comm["to"][id_key]) for comm in comms})

if not comm_ids.issubset(task_ids):
logging.error(
Expand Down
27 changes: 18 additions & 9 deletions scripts/LBDatafile_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from schema import And, Optional, Schema

def validate_id_and_seq_id(field):
"""Ensure that either seq_id or id is provided."""
if 'seq_id' not in field and 'id' not in field:
raise ValueError('Either id (bit-encoded) or seq_id must be provided.')
return field

LBDatafile_schema = Schema(
{
Optional('type'): And(str, "LBDatafile", error="'LBDatafile' must be chosen."),
Expand Down Expand Up @@ -30,15 +36,16 @@
'id': int,
'tasks': [
{
'entity': {
'entity': And({
Optional('collection_id'): int,
'home': int,
'id': int,
Optional('id'): int,
Optional('seq_id'): int,
Optional('index'): [int],
'type': str,
'migratable': bool,
Optional('objgroup_id'): int
},
}, validate_id_and_seq_id),
'node': int,
'resource': str,
Optional('subphases'): [
Expand All @@ -55,25 +62,27 @@
Optional('communications'): [
{
'type': str,
'to': {
'to': And({
'type': str,
'id': int,
Optional('id'): int,
Optional('seq_id'): int,
Optional('home'): int,
Optional('collection_id'): int,
Optional('migratable'): bool,
Optional('index'): [int],
Optional('objgroup_id'): int,
},
}, validate_id_and_seq_id),
'messages': int,
'from': {
'from': And({
'type': str,
'id': int,
Optional('id'): int,
Optional('seq_id'): int,
Optional('home'): int,
Optional('collection_id'): int,
Optional('migratable'): bool,
Optional('index'): [int],
Optional('objgroup_id'): int,
},
}, validate_id_and_seq_id),
'bytes': float
}
],
Expand Down
105 changes: 65 additions & 40 deletions src/vt/vrt/collection/balance/lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,54 @@
//@HEADER
*/

#include "vt/vrt/collection/balance/lb_data_holder.h"
#include "vt/context/context.h"
#include "vt/elm/elm_id_bits.h"
#include "vt/vrt/collection/balance/lb_data_holder.h"

#include <nlohmann/json.hpp>

namespace vt { namespace vrt { namespace collection { namespace balance {

void LBDataHolder::getObjectFromJsonField_(
nlohmann::json const& field, nlohmann::json& object, bool& is_bitpacked,
bool& is_collection) {
if (field.find("id") != field.end()) {
object = field["id"];
is_bitpacked = true;
} else {
object = field["seq_id"];
is_bitpacked = false;
}
vtAssertExpr(object.is_number());
if (field.find("collection_id") != field.end()) {
is_collection = true;
} else {
is_collection = false;
}
}
nlslatt marked this conversation as resolved.
Show resolved Hide resolved

ElementIDStruct
LBDataHolder::getElmFromCommObject_(
nlohmann::json const& field) const {
// Get the object's id and determine if it is bit-encoded
nlohmann::json object;
bool is_bitpacked, is_collection;
getObjectFromJsonField_(field, object, is_bitpacked, is_collection);

// Create elm with encoded data
ElementIDStruct elm;
if (is_collection and not is_bitpacked) {
int home = field["home"];
bool is_migratable = field["migratable"];
elm = elm::ElmIDBits::createCollectionImpl(
is_migratable, static_cast<ElementIDType>(object), home, this_node_);
} else {
elm = ElementIDStruct{object, this_node_};
}

return elm;
}

void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) const {
j["type"] = "object";
j["id"] = id.id;
Expand Down Expand Up @@ -278,7 +318,7 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

LBDataHolder::LBDataHolder(nlohmann::json const& j)
{
auto this_node = theContext()->getNode();
this_node_ = theContext()->getNode();

// read metadata for skipped and identical phases
readMetadata(j);
Expand All @@ -298,41 +338,35 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
auto time = task["time"];
auto etype = task["entity"]["type"];
auto home = task["entity"]["home"];
bool migratable = task["entity"]["migratable"];
bool is_migratable = task["entity"]["migratable"];

vtAssertExpr(time.is_number());
vtAssertExpr(node.is_number());

if (etype == "object") {
auto object = task["entity"]["id"];
vtAssertExpr(object.is_number());

auto elm = ElementIDStruct{object, node};
nlohmann::json object;
bool is_bitpacked, is_collection;
getObjectFromJsonField_(task["entity"], object, is_bitpacked, is_collection);

// Create elm
ElementIDStruct elm = is_collection and not is_bitpacked
? elm::ElmIDBits::createCollectionImpl(
is_migratable, static_cast<ElementIDType>(object), home, this_node_)
: ElementIDStruct{object, this_node_};
this->node_data_[id][elm].whole_phase_load = time;

if (
task["entity"].find("collection_id") != task["entity"].end() and
task["entity"].find("index") != task["entity"].end()
) {
using Field = uint64_t;
auto strippedObject = BitPackerType::getField<
vt::elm::eElmIDProxyBitsNonObjGroup::ID,
vt::elm::elm_id_num_bits,
Field
>(static_cast<Field>(object));
elm = elm::ElmIDBits::createCollectionImpl(migratable,
strippedObject,
home,
node);
if (is_collection) {
auto cid = task["entity"]["collection_id"];
auto idx = task["entity"]["index"];
if (cid.is_number() && idx.is_array()) {
std::vector<uint64_t> arr = idx;
auto proxy = static_cast<VirtualProxyType>(cid);
this->node_idx_[elm] = std::make_tuple(proxy, arr);
if (task["entity"].find("index") != task["entity"].end()) {
auto idx = task["entity"]["index"];
if (cid.is_number() && idx.is_array()) {
std::vector<uint64_t> arr = idx;
auto proxy = static_cast<VirtualProxyType>(cid);
this->node_idx_[elm] = std::make_tuple(proxy, arr);
}
}
}

this->node_data_[id][elm].whole_phase_load = time;

if (task.find("subphases") != task.end()) {
auto subphases = task["subphases"];
Expand Down Expand Up @@ -397,13 +431,8 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "object");

auto from_object = comm["from"]["id"];
vtAssertExpr(from_object.is_number());
auto from_elm = ElementIDStruct{from_object, this_node};

auto to_object = comm["to"]["id"];
vtAssertExpr(to_object.is_number());
auto to_elm = ElementIDStruct{to_object, this_node};
auto from_elm = getElmFromCommObject_(comm["from"]);
auto to_elm = getElmFromCommObject_(comm["to"]);

CommKey key(
CommKey::CollectionTag{},
Expand All @@ -420,9 +449,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
auto from_node = comm["from"]["id"];
vtAssertExpr(from_node.is_number());

auto to_object = comm["to"]["id"];
vtAssertExpr(to_object.is_number());
auto to_elm = ElementIDStruct{to_object, this_node};
auto to_elm = getElmFromCommObject_(comm["to"]);

CommKey key(
CommKey::NodeToCollectionTag{},
Expand All @@ -437,9 +464,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "node");

auto from_object = comm["from"]["id"];
vtAssertExpr(from_object.is_number());
auto from_elm = ElementIDStruct{from_object, this_node};
auto from_elm = getElmFromCommObject_(comm["from"]);

auto to_node = comm["to"]["id"];
vtAssertExpr(to_node.is_number());
Expand Down
25 changes: 25 additions & 0 deletions src/vt/vrt/collection/balance/lb_data_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,29 @@ struct LBDataHolder {

void addInitialTask(nlohmann::json& j, std::size_t n) const;

/**
* \brief Determine the object ID from the tasks or communication field of
* input JSON
*
* \param[in] field the json field containing an object ID
* \param[in] object empty json object to be populated with the object's ID
* \param[in] is_bitpacked empty bool to be populated with whether or not
* the ID is bit-encoded
* \param[in] is_collection empty bool to be populated with whether
* or not the object belongs to a collection
*/
static void getObjectFromJsonField_(
nlohmann::json const& field, nlohmann::json& object,
bool& is_bitpacked, bool& is_collection);

/**
* \brief Create an ElementIDStruct for the communication object
*
* \param[in] field the communication field for the desired object
* e.g. communications["to"] or communications["from"]
*/
ElementIDStruct getElmFromCommObject_(nlohmann::json const& field) const;

/**
* \brief Read the LB phase's metadata
*
Expand All @@ -135,6 +158,8 @@ struct LBDataHolder {
void readMetadata(nlohmann::json const& j);

public:
/// The current node
NodeType this_node_ = vt::uninitialized_destination;
/// Node attributes for the current rank
ElmUserDataType rank_attributes_;
/// Node timings for each local object
Expand Down
Loading