Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] passing ndarrays to acceptSubgraph API #17564

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -917,10 +917,15 @@ typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *j
const char* const* opt_vals, int num_opts);

#define MXLIB_PARTCALLACCEPTSUBGRAPH_STR "_partCallAcceptSubgraph"
typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph, const char *json,
int subgraph_id, int *accept, const char* const* opt_keys,
typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph,
const char *json, int subgraph_id,
int *accept, const char* const* opt_keys,
const char* const* opt_vals, int num_opts,
char*** attr_keys, char*** attr_vals, int *num_attrs);
char*** attr_keys, char*** attr_vals,
int *num_attrs, const char* const* in_args_chars,
void* const* in_args_data,
const int64_t* const *in_args_shapes,
const int* in_args_dims, const int* in_args_types);

#define MXLIB_INITIALIZE_STR "initialize"
typedef int (*initialize_t)(int version);
Expand Down Expand Up @@ -1283,7 +1288,10 @@ extern "C" {
_partCallAcceptSubgraph(acceptSubgraph_t acceptSubgraph, const char *json,
int subgraph_id, int *accept, const char* const* opt_keys,
const char* const* opt_vals, int num_opts,
char*** attr_keys, char*** attr_vals, int *num_attrs) {
char*** attr_keys, char*** attr_vals, int *num_attrs,
const char* const* in_args_chars, void* const* in_args_data,
const int64_t* const* in_args_shapes, int* in_args_dims,
int* in_args_types) {
std::string subgraph_json(json);
bool accept_bool = false;
// create map of attributes from list
Expand Down
6 changes: 5 additions & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1351,8 +1351,9 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
NDArray **in_args_ptr = nullptr;
if (len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(len);
nnvm::DTypeVector arg_dtypes(len);
Expand Down Expand Up @@ -1382,6 +1383,9 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
g.GetAttr<StorageTypeVector>("storage_type"));
}
}
g.attrs["args"] = std::make_shared<nnvm::any>(in_args_ptr);
std::vector<std::string> names = sym->ListInputNames(nnvm::Symbol::ListInputOption(1));
g.attrs["arg_names"] = std::make_shared<nnvm::any>(names);
std::vector<std::pair<std::string, std::string>> options_map;
for (mx_uint i = 0; i < num_options; ++i) {
options_map.emplace_back(keys[i], vals[i]);
Expand Down
36 changes: 33 additions & 3 deletions src/operator/subgraph/partitioner/custom_subgraph_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,27 @@ class CustomSubgraphProperty: public SubgraphProperty {
const std::vector<std::pair<std::string, std::string>>& options_map) {
// clear supported_nodes to remove state from previous calls
supported_nodes.clear();

in_args_ptr = g.GetAttr<NDArray**>("args");
in_args_names = g.GetAttr<std::vector<std::string>>("arg_names");

for(std::string s : in_args_names) {
in_args_chars.push_back(s.c_str());
}

in_args_data.clear();
in_args_shapes.clear();
in_args_dims.clear();
in_args_types.clear();

// convert NDarrays to constituent parts
for (size_t i = 0; i < in_args_names.size(); i++) {
in_args_data.push_back(in_args_ptr[i]->data().dptr_);
in_args_shapes.push_back(in_args_ptr[i]->shape().data());
in_args_dims.push_back(in_args_ptr[i]->shape().ndim());
in_args_types.push_back(in_args_ptr[i]->dtype());
}

// remove all graph attrs, some cannot be saved to json
nnvm::Graph graph = std::move(g);
graph.attrs.clear();
Expand Down Expand Up @@ -189,9 +209,12 @@ class CustomSubgraphProperty: public SubgraphProperty {

std::string subgraph_json = nnvm::pass::SaveJSON(g);
CHECK(call_accept_subgraph_(accept_subgraph_, subgraph_json.c_str(),
subgraph_id, &accept, opt_keys_.data(),
opt_vals_.data(), opt_keys_.size(),
&attr_keys, &attr_vals, &num_attr))
subgraph_id, &accept, opt_keys_.data(),
opt_vals_.data(), opt_keys_.size(),
&attr_keys, &attr_vals, &num_attr,
in_args_chars.data(), in_args_data.data(),
in_args_shapes.data(), in_args_dims.data(),
in_args_types.data()))
<< "Error calling accept_subgraph for '" << subgraph_prop << "'";
}
if (accept) {
Expand Down Expand Up @@ -228,6 +251,13 @@ class CustomSubgraphProperty: public SubgraphProperty {
std::string subgraph_op_name;
std::vector<std::pair<std::string, std::string>> options_map_;
std::vector<const char*> opt_keys_, opt_vals_;
NDArray **in_args_ptr;
std::vector<std::string> in_args_names;
std::vector<const char*> in_args_chars;
std::vector<void*> in_args_data;
std::vector<const int64_t *> in_args_shapes;
std::vector<int> in_args_dims;
std::vector<int> in_args_types;
};
} // namespace op
} // namespace mxnet
Expand Down