Skip to content

Commit

Permalink
perf: Improve edge classifier performance (#2734)
Browse files Browse the repository at this point in the history
For the case of 0 chunks. The value of the chunk option is actually questionable, but this should be addressed at a different point.
  • Loading branch information
benjaminhuth authored Nov 28, 2023
1 parent 8f4162f commit 949678d
Showing 1 changed file with 39 additions and 21 deletions.
60 changes: 39 additions & 21 deletions Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,53 @@ std::tuple<std::any, std::any, std::any> TorchEdgeClassifier::operator()(
auto nodes = std::any_cast<torch::Tensor>(inputNodes).to(device);
auto edgeList = std::any_cast<torch::Tensor>(inputEdges).to(device);

auto model = m_model->clone();
model.to(device);

if (m_cfg.numFeatures > nodes.size(1)) {
throw std::runtime_error("requested more features then available");
}

std::vector<at::Tensor> results;
results.reserve(m_cfg.nChunks);

auto edgeListTmp =
m_cfg.undirected ? torch::cat({edgeList, edgeList.flip(0)}, 1) : edgeList;

std::vector<torch::jit::IValue> inputTensors(2);
inputTensors[0] = m_cfg.numFeatures < nodes.size(1)
? nodes.index({Slice{}, Slice{None, m_cfg.numFeatures}})
: nodes;

const auto chunks = at::chunk(at::arange(edgeListTmp.size(1)), m_cfg.nChunks);
for (const auto& chunk : chunks) {
ACTS_VERBOSE("Process chunk");
inputTensors[1] = edgeListTmp.index({Slice(), chunk});

results.push_back(m_model->forward(inputTensors).toTensor());
results.back().squeeze_();
results.back().sigmoid_();
torch::Tensor output;

// Scope this to keep inference objects separate
{
auto edgeListTmp = m_cfg.undirected
? torch::cat({edgeList, edgeList.flip(0)}, 1)
: edgeList;

std::vector<torch::jit::IValue> inputTensors(2);
inputTensors[0] =
m_cfg.numFeatures < nodes.size(1)
? nodes.index({Slice{}, Slice{None, m_cfg.numFeatures}})
: nodes;

if (m_cfg.nChunks > 1) {
std::vector<at::Tensor> results;
results.reserve(m_cfg.nChunks);

auto chunks = at::chunk(edgeListTmp, m_cfg.nChunks, 1);
for (auto& chunk : chunks) {
ACTS_VERBOSE("Process chunk with shape" << chunk.sizes());
inputTensors[1] = chunk;

results.push_back(model.forward(inputTensors).toTensor());
results.back().squeeze_();
}

output = torch::cat(results);
} else {
inputTensors[1] = edgeListTmp;
output = model.forward(inputTensors).toTensor();
output.squeeze_();
}
}

auto output = torch::cat(results);
output.sigmoid_();

if (m_cfg.undirected) {
output = output.index({Slice(None, output.size(0) / 2)});
auto newSize = output.size(0) / 2;
output = output.index({Slice(None, newSize)});
}

ACTS_VERBOSE("Size after classifier: " << output.size(0));
Expand Down

0 comments on commit 949678d

Please sign in to comment.