Skip to content

Commit

Permalink
Merge pull request #10 from andrzejnovak/switch
Browse files Browse the repository at this point in the history
Run TF only when features are filled
  • Loading branch information
andrzejnovak authored Nov 21, 2018
2 parents 92bb0d6 + 5703a71 commit 5f8c3a4
Showing 1 changed file with 29 additions and 24 deletions.
53 changes: 29 additions & 24 deletions RecoBTag/TensorFlow/plugins/DeepDoubleXTFJetTagsProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ void DeepDoubleXTFJetTagsProducer::globalEndJob(const DeepDoubleXTFCache* cache)

void DeepDoubleXTFJetTagsProducer::produce(edm::Event& iEvent, const edm::EventSetup& iSetup)
{

edm::Handle<TagInfoCollection> tag_infos;
iEvent.getByToken(src_, tag_infos);

// initialize output collection
std::vector<std::unique_ptr<JetTagCollection>> output_tags;
for (std::size_t i=0; i < flav_pairs_.size(); i++) {
Expand All @@ -220,12 +220,17 @@ void DeepDoubleXTFJetTagsProducer::produce(edm::Event& iEvent, const edm::EventS
output_tags.emplace_back(std::make_unique<JetTagCollection>());
}
}

const int64_t n_jets = tag_infos->size();

// count jets to actually run inference on
int64_t n_jets = 0;
for (std::size_t i=0; i < tag_infos->size(); i++){
const auto & features = tag_infos->at(i).features();
if (!features.empty()) n_jets += 1 ;
}
// count all jets to generate output_tags for each and set to default (-1)
const int64_t n_jets_all = tag_infos->size();
// either all jets or one per batch for the time being
const int64_t n_batch_jets = batch_eval_ ? n_jets : 1;


std::vector<tensorflow::TensorShape> input_sizes {
{n_batch_jets, 1, 27}, // input_1 - global double-b features
{n_batch_jets, 60, 8}, // input_2 - charged pf
Expand All @@ -248,27 +253,28 @@ void DeepDoubleXTFJetTagsProducer::produce(edm::Event& iEvent, const edm::EventS
for (std::size_t i=0; i < lp_tensors_.size(); i++) {
input_tensors[input_sizes.size() + i] = tensorflow::NamedTensor(lp_names_[i], lp_tensors_[i]);
}

std::size_t n_batches = n_jets/n_batch_jets; // either 1 or n_jets
std::size_t n_batches = n_jets_all/n_batch_jets; // either 1 or n_jets
for (std::size_t batch_n=0; batch_n < n_batches; batch_n++) {

// tensors have to be zeroed before filling per batch
for (std::size_t i=0; i < input_sizes.size(); i++) {
input_tensors[i].second.flat<float>().setZero();
}

// don't run batch unless filled
bool run_this_batch = false;
// fill values of the input tensors
for (std::size_t jet_bn=0; jet_bn < (std::size_t) n_batch_jets; jet_bn++) {

// global jet index (jet_bn is the jet batch index)
std::size_t jet_n = batch_n*n_batch_jets + jet_bn;

// jet and other global features
// only fill if features not empty
const auto & features = tag_infos->at(jet_n).features();
if (features.empty()) continue ;
// if at least one jet has features, run inferences
run_this_batch = true;
// jet and other global features
db_tensor_filler(input_tensors.at(kGlobal).second, jet_bn, features);


// c_pf candidates
auto max_c_pf_n = std::min(features.c_pf_features.size(),
(std::size_t) input_sizes.at(kChargedCandidates).dim_size(1));
Expand All @@ -290,36 +296,35 @@ void DeepDoubleXTFJetTagsProducer::produce(edm::Event& iEvent, const edm::EventS
}
// run the session
std::vector<tensorflow::Tensor> outputs;
if (run_this_batch){
tensorflow::run(session_, input_tensors, output_names_, &outputs);

}
// set output values for flavour probs
for (std::size_t jet_bn=0; jet_bn < (std::size_t) n_batch_jets; jet_bn++) {

// global jet index (jet_bn is the jet batch index)
std::size_t jet_n = batch_n*n_batch_jets + jet_bn;

const auto & features = tag_infos->at(jet_n).features();

const auto & jet_ref = tag_infos->at(jet_n).jet();
for (std::size_t flav_n=0; flav_n < flav_pairs_.size(); flav_n++) {
const auto & flav_pair = flav_pairs_.at(flav_n);
float o_sum = 0.;
for (const unsigned int & ind : flav_pair.second) {
o_sum += outputs.at(kJetFlavour).matrix<float>()(jet_bn, ind);
}
if (!features.empty()) {
(*(output_tags.at(flav_n)))[jet_ref] = o_sum;
} else {
(*(output_tags.at(flav_n)))[jet_ref] = -1.;
if (!features.empty()) {
for (const unsigned int & ind : flav_pair.second) {
o_sum += outputs.at(kJetFlavour).matrix<float>()(jet_bn, ind);
}
(*(output_tags.at(flav_n)))[jet_ref] = o_sum;
} else {
(*(output_tags.at(flav_n)))[jet_ref] = -1.;
}
}
}
}
}

for (std::size_t i=0; i < flav_pairs_.size(); i++) {
iEvent.put(std::move(output_tags[i]), flav_pairs_.at(i).first);
}

}

//define this as a plug-in
Expand Down

0 comments on commit 5f8c3a4

Please sign in to comment.