Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][MSC][M2.1] Add pruner for model pruning #16186

Merged
merged 3 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 42 additions & 14 deletions python/tvm/contrib/msc/core/codegen/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_base_h_code() -> str:
#include <cassert>
#include <fstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -78,11 +79,19 @@ class DatasetReader {

bool ReadNext(void* buffers[], int num_datas = -1);

const std::vector<std::string> GetTensorNames() { return tensor_names_; }

size_t GetTensorSize(const std::string& name);

const std::string GetSaveName(const std::string& name);

private:
std::string folder_;
size_t max_size_;
size_t cur_cnt_;
std::vector<std::pair<std::string, size_t>> tensor_info_;
std::vector<std::string> tensor_names_;
std::unordered_map<std::string, std::string> save_names_;
std::unordered_map<std::string, size_t> tensor_sizes_;
};

} // namespace msc
Expand All @@ -102,10 +111,10 @@ def get_base_cc_code() -> str:
The base cc source.
"""

return """#include <algorithm>
#include <fstream>
return """#include "base.h"

#include "base.h"
#include <algorithm>
#include <fstream>

namespace tvm {
namespace contrib {
Expand All @@ -122,23 +131,31 @@ def get_base_cc_code() -> str:

DatasetReader::DatasetReader(const std::string& folder, int max_size) {
folder_ = folder;
const std::string info_file = folder_ + "/tensor_info";
const std::string info_file = folder_ + "/datas_info.txt";
std::ifstream input(info_file, std::ios::binary);
assert(input.is_open() && ("Failed to open file " + info_file).c_str());
std::string line;
while (getline(input, line)) {
// define name
int pos = line.find(" ");
assert(pos > 0 && ("Can not find space in line " + line).c_str());
const auto& name = line.substr(0, pos);
const auto& byte_size = line.substr(pos + 1, line.size());
tensor_info_.push_back(std::make_pair(name, static_cast<size_t>(std::stoi(byte_size))));
tensor_names_.push_back(name);
const auto& left = line.substr(pos + 1, line.size());
// define save_name
pos = left.find(" ");
assert(pos > 0 && ("Can not find space in left " + left).c_str());
save_names_[name] = left.substr(0, pos);
// define size
const auto& byte_size = left.substr(pos + 1, left.size());
tensor_sizes_[name] = static_cast<size_t>(std::stoi(byte_size));
}
size_t file_cnt = 0;
while (true) {
bool all_exists = true;
for (const auto& pair : tensor_info_) {
for (const auto& pair : save_names_) {
const auto& d_file =
folder_ + "/" + pair.first + "/batch_" + std::to_string(file_cnt) + ".bin";
folder_ + "/" + pair.second + "/batch_" + std::to_string(file_cnt) + ".bin";
if (!FileUtils::FileExist(d_file)) {
all_exists = false;
break;
Expand All @@ -160,19 +177,30 @@ def get_base_cc_code() -> str:
if (cur_cnt_ >= max_size_) {
return false;
}
size_t max_num = num_datas > 0 ? static_cast<size_t>(num_datas) : tensor_info_.size();
max_num = std::min(max_num, tensor_info_.size());
size_t max_num = num_datas > 0 ? static_cast<size_t>(num_datas) : tensor_names_.size();
max_num = std::min(max_num, tensor_names_.size());
for (size_t i = 0; i < max_num; i++) {
const auto& pair = tensor_info_[i];
const auto& d_file = folder_ + "/" + pair.first + "/batch_" + std::to_string(cur_cnt_) + ".bin";
if (!FileUtils::ReadToBuffer(d_file, (char*)buffers[i], pair.second)) {
const auto& name = tensor_names_[i];
const auto& d_file =
folder_ + "/" + GetSaveName(name) + "/batch_" + std::to_string(cur_cnt_) + ".bin";
if (!FileUtils::ReadToBuffer(d_file, (char*)buffers[i], GetTensorSize(name))) {
return false;
}
}
cur_cnt_++;
return true;
}

// Returns the byte size recorded in tensor_sizes_ for the tensor `name`.
// The lookup uses at() rather than operator[]: operator[] would insert a
// zero-sized entry for an unknown name once asserts are compiled out (NDEBUG),
// whereas at() throws std::out_of_range and leaves the map untouched.
size_t DatasetReader::GetTensorSize(const std::string& name) {
  assert(tensor_sizes_.count(name));
  return tensor_sizes_.at(name);
}

// Returns the save name recorded in save_names_ for the tensor `name`; the
// reader uses it as the sub-folder holding that tensor's batch files.
// The lookup uses at() rather than operator[]: operator[] would insert an
// empty save name for an unknown tensor once asserts are compiled out
// (NDEBUG), whereas at() throws std::out_of_range and leaves the map untouched.
const std::string DatasetReader::GetSaveName(const std::string& name) {
  assert(save_names_.count(name));
  return save_names_.at(name);
}

} // namespace msc
} // namespace contrib
} // namespace tvm
Expand Down
166 changes: 160 additions & 6 deletions python/tvm/contrib/msc/core/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def dim_at(self, axis: Union[int, str]) -> int:
return int(self.shape[axis])
return int(_ffi_api.MSCTensorDimAt(self, axis))

def layout_of(self, axis: str) -> int:
    """Get the index of an axis in the layout of the tensor

    Parameters
    ----------
    axis: str
        The axis name to look up in the layout.

    Returns
    -------
    index: int
        The value reported by the layout's index_of for the axis.
    """

    tensor_layout = self.layout
    return tensor_layout.index_of(axis)

def set_alias(self, alias: str):
"""Set alis for the tensor

Expand Down Expand Up @@ -162,7 +165,6 @@ def __init__(
outputs: List[MSCTensor],
weights: Dict[str, MSCTensor],
):

parents = [i[0] for i in inputs]
out_indices = [i[1] for i in inputs]
self.__init_handle_by_constructor__(
Expand Down Expand Up @@ -350,10 +352,12 @@ class WeightJoint(BaseJoint):
The optype of the node.
wtype: string
The weight type of the node.
strategy: string
The prune strategy of the node.
weight: MSCTensor
The weight of the node.
attrs: dict<string, string>
The attributes of the node.
weight: MSCTensor,
The weight of the node.
parents: list<WeightJoint>
The parents of the node.
friends: list<WeightJoint>
Expand All @@ -367,25 +371,71 @@ def __init__(
shared_ref: str,
optype: str,
wtype: str,
attrs: Dict[str, str],
strategy: str,
weight: MSCTensor,
attrs: Dict[str, str],
parents: List[BaseJoint],
friends: List[BaseJoint],
):

self.__init_handle_by_constructor__(
_ffi_api.WeightJoint,
index,
name,
shared_ref,
optype,
wtype,
attrs,
strategy,
weight,
attrs,
parents,
friends,
)

def get_attrs(self) -> Dict[str, str]:
    """Collect every attribute stored on this weight node.

    Returns
    -------
    attributes: dict<str, str>
        The attributes of the node, as returned by the WeightJointGetAttrs
        FFI call.
    """

    attributes = _ffi_api.WeightJointGetAttrs(self)
    return attributes

def get_attr(self, key: str, default: Optional[Any] = None) -> str:
    """Look up a single attribute of the node by key.

    Parameters
    ----------
    key: str
        The key of the attribute.
    default: Any
        The value returned when the key is missing.

    Returns
    -------
    attribute: str
        The attribute stored under key, or default when the key is absent.
    """

    node_attrs = self.get_attrs()
    return node_attrs.get(key, default)

def has_attr(self, key: str) -> bool:
    """Tell whether the node carries an attribute named key.

    Parameters
    ----------
    key: str
        The key of the attribute.

    Returns
    -------
    has_attr: bool
        Whether the key is in the attributes; the result of the
        WeightJointHasAttr FFI call converted to bool.
    """

    found = _ffi_api.WeightJointHasAttr(self, key)
    return bool(found)


class BaseGraph(Object):
"""Base class of all MSC Graphs."""
Expand Down Expand Up @@ -727,6 +777,110 @@ def __init__(
nodes,
)

def has_node(self, name: str) -> bool:
    """Tell whether a weight node with the given name is in the graph.

    Parameters
    ----------
    name: string
        The name of the node.

    Returns
    -------
    has_node: bool
        Whether the node is in the graph; the result of the
        WeightGraphHasNode FFI call converted to bool.
    """

    found = _ffi_api.WeightGraphHasNode(self, name)
    return bool(found)

def find_node(self, name: str) -> WeightJoint:
    """Look up a weight node of the graph by its name.

    Parameters
    ----------
    name: string
        The name of the node.

    Returns
    -------
    node: WeightJoint
        The found node, as returned by the WeightGraphFindNode FFI call.
    """

    node = _ffi_api.WeightGraphFindNode(self, name)
    return node

def get_nodes(self) -> Iterable[WeightJoint]:
    """Iterate lazily over the weight nodes of the graph.

    Returns
    -------
    nodes: generator<WeightJoint>
        The generator of nodes, resolved one by one through find_node in
        the order given by node_names.
    """

    yield from (self.find_node(node_name) for node_name in self.node_names)

def to_json(self) -> str:
    """Serialize the graph to json.

    Returns
    -------
    graph_json: string
        The graph in json format, as produced by the WeightGraphToJson
        FFI call.
    """

    graph_json = _ffi_api.WeightGraphToJson(self)
    return graph_json

def inspect(self) -> dict:
    """Summarize the graph as node counts.

    Returns
    -------
    graph_des: dict
        The graph description: a "nodes" dict holding the "total" number of
        nodes plus one counter per weight_type seen in the graph.
    """

    node_counts = {"total": 0}
    for weight_node in self.get_nodes():
        node_counts["total"] += 1
        # Counted after "total" so the per-type entries keep first-seen order.
        w_type = weight_node.weight_type
        node_counts[w_type] = node_counts.get(w_type, 0) + 1
    return {"nodes": node_counts}

@classmethod
def from_json(cls, json_str: str) -> BaseGraph:
    """Build a weight graph from json.

    Parameters
    ----------
    json_str: string
        The file_path or json string.

    Returns
    -------
    graph: WeightGraph
        The graph created by the WeightGraphFromJson FFI call.
    """

    # The input is parsed with load_dict and re-serialized with dump_dict
    # before it reaches the FFI loader.
    normalized_json = msc_utils.dump_dict(msc_utils.load_dict(json_str))
    return _ffi_api.WeightGraphFromJson(normalized_json)

def clone(self) -> BaseGraph:
    """Clone the graph through a json round trip.

    Returns
    -------
    new_graph: WeightGraph
        The cloned graph, rebuilt by this graph's own from_json.
    """

    # Use this graph's own loader: to_json here emits a weight graph dump
    # (WeightGraphToJson), so it must be read back by the matching from_json
    # (WeightGraphFromJson) and not by MSCGraph.from_json.
    return type(self).from_json(self.to_json())

def visualize(self, path: Optional[str] = None) -> str:
"""Dump the graph to prototxt format.

Expand Down
Loading
Loading