Skip to content

Commit

Permalink
Prepare to generate code about tensor_inputs_ and tensor_outputs_
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed May 13, 2019
1 parent 2c00e2e commit 1ed7bb2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
5 changes: 5 additions & 0 deletions dnnlibrary/include/ModelBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <memory>
#include <numeric>
#include <optional>
#include <set>
#include <string>
#include <vector>

Expand Down Expand Up @@ -42,6 +43,10 @@ class ModelBuilder {
std::map<float, Index> float32_operand_map_;
std::map<float, Index> float32_as_tensor_operand_map_;
StrKeyMap<android::nn::wrapper::OperandType> operand_types_;
// tensor_inputs_ and tensor_outputs_ is to automatically determine the
// output of the model
std::set<std::string> tensor_inputs_;
std::set<std::string> tensor_outputs_;

uint32_t int32_missing_index = UINT32_MAX;
uint32_t float32_missing_index = UINT32_MAX;
Expand Down
5 changes: 4 additions & 1 deletion generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def add_tensor_operand(operand):
if operand['predefined'] == 'optional_bias':
return add_optional_bias()
if operand['cpp_type'] == 'str':
return '''const auto {0}_idx = operand_indexes_.at({0});
return '''tensor_inputs.insert({0});
const auto {0}_idx = operand_indexes_.at({0});
input_indexes.push_back({0}_idx);'''.format(operand['name'])
elif operand['cpp_type'] == 'float':
return '''const auto {0}_idx = FillOperand("input_{0}_of_" + output, {{Type::TENSOR_FLOAT32, {{1}}}}, {0});
Expand All @@ -85,6 +86,7 @@ def add_tensor_operand(operand):
input_indexes.push_back({0}_idx);'''.format(operand['name'])
elif operand['cpp_type'] == 'str_list':
return '''for (const auto &x : {}) {{
tensor_inputs_.insert(x);
input_indexes.push_back(operand_indexes_.at(x));
}}'''.format(operand['name'])
else:
Expand Down Expand Up @@ -294,6 +296,7 @@ def generate_model_builder():
'AddOperation(ANEURALNETWORKS_{}, input_indexes, operand_type)[0];'.format(op['nnapi']))
cogout(
'''RegisterOperand(output, output_idx, operand_type);
tensor_outputs_.insert(output);
return output_idx;
}
'''
Expand Down

0 comments on commit 1ed7bb2

Please sign in to comment.