[Relax] support scatter ops (#17509)
* support scatter ops

* format fix

* format fix
Archermmt authored Nov 11, 2024
1 parent d5b9f5c commit e3665ae
Showing 14 changed files with 509 additions and 37 deletions.
5 changes: 4 additions & 1 deletion python/tvm/contrib/msc/core/codegen/codegen.py
@@ -224,6 +224,7 @@ def relay_to_relax(
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
) -> tvm.IRModule:
"""Change relay IRModule to relax MSCGraph.
@@ -239,6 +240,8 @@ def relay_to_relax(
The config for build MSCGraph.
opt_config: dict
The config for optimizing the relay module before translation.
build_folder: MSCDirectory
The folder for saving scripts and data.
Returns
-------
@@ -254,4 +257,4 @@ def relay_to_relax(
opt_config=opt_config,
)

return to_relax(graph, weights, codegen_config={"from_relay": True})
return to_relax(graph, weights, codegen_config={"from_relay": True}, build_folder=build_folder)
48 changes: 47 additions & 1 deletion python/tvm/contrib/msc/core/utils/dataset.py
@@ -20,12 +20,13 @@
import os
import shutil
import json
from typing import List, Union, Dict, Any
from typing import List, Union, Dict, Any, Tuple
import numpy as np

import tvm
from .arguments import load_dict
from .info import cast_array, is_array
from .namespace import MSCFramework


def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], style="dict") -> Any:
@@ -64,6 +65,51 @@ def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], style="dict") -> Any:
raise TypeError("Unexpected style " + str(style))


def random_data(
info: Union[List, Tuple, dict],
framework: str = MSCFramework.MSC,
device: str = "cpu",
max_val: int = None,
) -> Any:
"""Create random data from info
Parameters
----------
info: list| tuple| dict
The data info.
framework: str
The framework.
device: str
The device.
"""

if isinstance(info, (tuple, list)):
if len(info) == 1:
info = {"name": "data", "shape": info[0], "dtype": "float32"}
elif len(info) == 2:
info = {"name": "data", "shape": info[0], "dtype": info[1]}
elif len(info) == 3:
info = {"name": info[0], "shape": info[1], "dtype": info[2]}
else:
raise Exception("Unexpected info " + str(info))
assert isinstance(info, dict) and all(
key in info for key in ["shape", "dtype"]
), "shape and dtype should be given to create randome data"
if info["dtype"] in ("int32", "int64"):
if max_val is None:
data = np.zeros(info["shape"]).astype(info["dtype"])
else:
data = np.random.randint(0, high=max_val, size=info["shape"]).astype(info["dtype"])
elif info["dtype"] == "bool":
data = np.random.rand(*info["shape"]).astype("float32")
data = np.where(data >= 0.5, True, False)
else:
data = np.random.rand(*info["shape"]).astype(info["dtype"])
if max_val is not None:
data *= max_val
return cast_array(data, framework, device=device)

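A minimal usage sketch of the new helper (shapes and names are illustrative; note that integer dtypes produce all-zero data unless max_val bounds the random range):

    x = random_data([[1, 3, 224, 224], "float32"])          # [shape, dtype] form
    idx = random_data(["idx", [4, 4], "int64"], max_val=8)  # [name, shape, dtype] form, ints in [0, 8)
    flags = random_data([[16], "bool"])                     # booleans from a 0.5 threshold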

class BaseDataLoader(object):
"""Basic dataset loader for MSC
18 changes: 15 additions & 3 deletions python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -25,6 +25,7 @@
from tvm.contrib.msc.core.ir.graph import MSCGraph
from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs
from tvm.contrib.msc.core.codegen import relay_to_relax
from tvm.contrib.msc.core import utils as msc_utils


def set_weight_alias(graph: MSCGraph) -> MSCGraph:
@@ -70,6 +71,7 @@ def from_torch(
opt_config: Optional[Dict[str, str]] = None,
as_msc: bool = True,
custom_convert_map: dict = None,
build_folder: msc_utils.MSCDirectory = None,
) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change torch nn.Module to MSCGraph.
@@ -93,6 +95,8 @@
Set to True to return an MSCGraph, otherwise a relax mod
custom_convert_map: dict
The convert map for plugin
build_folder: MSCDirectory
The folder for saving scripts and data.
Returns
-------
@@ -102,9 +106,15 @@
The weights from the IRModule.
"""

# try torch.fx.symbolic_trace; fall back to the scripted/relay path on failure
if via_relax:
input_info = normalize_inputs(input_info)
graph_model, params = torch.fx.symbolic_trace(model), None
try:
graph_model = torch.fx.symbolic_trace(model)
except: # pylint: disable=bare-except
via_relax = False

if via_relax:
input_info, params = normalize_inputs(input_info), None
with torch.no_grad():
relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
else:
@@ -122,7 +132,9 @@
relay_mod, params = tvm.relay.frontend.from_pytorch(
scripted_model, shape_list, custom_convert_map=custom_convert_map
)
relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config)
relax_mod = relay_to_relax(
relay_mod, params, trans_config, build_config, opt_config, build_folder=build_folder
)
if not as_msc:
return relax_mod, params
graph, weights = from_relax(relax_mod, trans_config=trans_config, build_config=build_config)
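A minimal sketch of calling the updated entry point (assumes the (shape, dtype) input_info convention used in the MSC tests; the model and shapes are hypothetical):

    import torch
    from tvm.contrib.msc.framework.torch.frontend import translate

    model = torch.nn.Linear(8, 4).eval()
    graph, weights = translate.from_torch(model, [([1, 8], "float32")])
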
3 changes: 1 addition & 2 deletions python/tvm/contrib/msc/pipeline/pipeline.py
@@ -21,7 +21,6 @@
import json
from typing import Any, Union, List, Tuple
import traceback
import numpy as np

from tvm.contrib.msc.core.tools import get_tool_cls, BaseTool
from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey
@@ -678,7 +677,7 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any:
def get_random():
def _to_data(inp):
shape = [1 if isinstance(d, str) else d for d in inp[1]]
return np.random.rand(*shape).astype(inp[2])
return msc_utils.random_data([shape, inp[2]])

for _ in range(max_batch):
yield {i[0]: _to_data(i) for i in self._config["inputs"]}
32 changes: 32 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -783,6 +783,20 @@ def _reshape(self, node: fx.Node) -> relax.Var:
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.reshape(x, dims))

def _scatter(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if len(node.args) == 1:
dim = node.kwargs["dim"]
index = self.env[node.kwargs["index"]]
src = self.env[node.kwargs["src"]]
elif len(node.args) == 4:
dim = node.args[1]
index = self.env[node.args[2]]
src = self.env[node.args[3]]
else:
raise Exception("Unexpected args " + str(node.args))
return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))

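For context, the torch semantics this handler covers and the relax op it lowers to (worked example, not part of the diff):

    import torch

    x = torch.zeros(3, 5)
    index = torch.tensor([[0, 1, 2]])
    src = torch.tensor([[1.0, 2.0, 3.0]])
    # along dim 0: out[index[i][j]][j] = src[i][j], so out carries 1, 2, 3 on its diagonal;
    # the handler emits relax.op.scatter_elements(x, index, src, axis=0)
    out = x.scatter(0, index, src)
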
def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
split_size = node.args[1]
@@ -801,6 +815,24 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
return self.block_builder.emit(relax.op.squeeze(x, dim))

def _stack(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
in_args = args[0]
assert all(
a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:]
), "Expect all dim at {} to be the same, get {}".format(
axis, [a.struct_info.shape for a in args]
)
cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
s_shape = []
for idx, s in enumerate(cat.struct_info.shape):
if idx == axis:
s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]])
else:
s_shape.append(s)
return self.block_builder.emit(relax.op.reshape(cat, s_shape))

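_stack lowers torch.stack to a concat followed by a reshape; the equivalence it relies on can be checked directly (illustrative only):

    import torch

    a, b, c = (torch.randn(2, 4) for _ in range(3))
    stacked = torch.stack([a, b, c], dim=1)                  # shape (2, 3, 4)
    emulated = torch.cat([a, b, c], dim=1).reshape(2, 3, 4)  # what _stack emits
    assert torch.equal(stacked, emulated)
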
def _tile(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -676,9 +676,11 @@ def create_convert_map(
"permute": self._permute,
"repeat": self._repeat,
"reshape": self._reshape,
"scatter": self._scatter,
"size": self._size,
"split": self._split,
"squeeze": self._squeeze,
"stack": self._stack,
"tile": self._tile,
"transpose": self._transpose,
"unsqueeze": lambda node: self.block_builder.emit(
56 changes: 53 additions & 3 deletions src/contrib/msc/framework/torch/torch_opcode.cc
@@ -214,14 +214,28 @@ class TorchConstantCodeGen : public TorchOpCode {

protected:
void CodeGenInit() final {
const auto& dtype = node()->OutputAt(0)->DTypeName();
const auto& ref_name = StringUtils::Replace(node()->name, ".", "_");
if (node()->HasAttr("scalar")) {
if (node()->OutputAt(0)->DTypeName() == "int32") {
if (dtype == "int32") {
stack_.assign(module_ref(), node()->GetTypeAttr<int>("scalar"));
} else if (node()->OutputAt(0)->DTypeName() == "int64") {
} else if (dtype == "int64") {
stack_.assign(module_ref(), node()->GetTypeAttr<int64_t>("scalar"));
} else if (node()->OutputAt(0)->DTypeName() == "float32") {
} else if (dtype == "float32") {
stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
}
} else if (dtype == "int32") {
stack_.func_call("register_buffer", "", "self")
.call_arg(DocUtils::ToStr(ref_name))
.inplace_start("torch.IntTensor")
.call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
.inplace_end();
} else if (dtype == "int64") {
stack_.func_call("register_buffer", "", "self")
.call_arg(DocUtils::ToStr(ref_name))
.inplace_start("torch.LongTensor")
.call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
.inplace_end();
} else {
stack_.func_call("torch.Tensor", "data")
.call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
@@ -565,6 +579,39 @@ class TorchSimpleCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchSimpleCodeGen);
};

class TorchScatterElementsCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchScatterElementsCodeGen)

protected:
void CodeGenForward() final {
if (node()->InputAt(1)->DTypeName() == "int32") {
stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
}
stack_.op_call()
.op_input_arg()
.op_arg<int>("axis", "dim")
.op_input_arg(1, "index")
.op_input_arg(2, "src");
}
};

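Roughly the forward code this generates (variable names hypothetical; the cast is emitted only for int32 indices):

    index = index.to(torch.int64)
    out = torch.scatter(x, dim=0, index=index, src=src)  # dim comes from the node's axis attribute
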
class TorchScatterNDCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchScatterNDCodeGen)

protected:
void CodeGenForward() final {
if (node()->InputAt(1)->DTypeName() == "int32") {
stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
}
// relax adds an extra dim for indices
if (node()->InputAt(1)->Ndim() == node()->OutputAt(0)->Ndim()) {
stack_.func_call("squeeze", IdxInput(1), IdxInput(1)).call_arg(-1);
}
stack_.assign(DocUtils::ToIndex(IdxInput(0), IdxInput(1)), IdxInput(2))
.assign(IdxNode(), IdxInput(0));
}
};

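The scatter_nd lowering becomes plain advanced indexing (sketch with hypothetical names; both fix-ups are conditional as shown above):

    index = index.to(torch.int64)  # only for int32 indices
    index = index.squeeze(-1)      # only when relax padded an extra dim
    x[index] = src
    out = x
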
class TorchSplitCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchSplitCodeGen)

@@ -719,6 +766,9 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
map->emplace("permute_dims", std::make_shared<TorchPermuteDimsCodeGen>("", "torch.permute"));
map->emplace("repeat", std::make_shared<TorchRepeatCodeGen>("", "repeat"));
map->emplace("reshape", std::make_shared<TorchReshapeCodeGen>("", "torch.reshape"));
map->emplace("scatter_elements",
std::make_shared<TorchScatterElementsCodeGen>("", "torch.scatter"));
map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));

46 changes: 46 additions & 0 deletions src/contrib/msc/framework/tvm/relax_opcode.cc
@@ -568,6 +568,34 @@ class RelaxReshapeCodeGen : public RelaxOpCode {
}
};

class RelaxScatterElementsCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxScatterElementsCodeGen)

protected:
void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false).op_arg<int>("axis"); }
};

class RelaxScatterNDCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxScatterNDCodeGen)

protected:
void CodeGenBuild() final {
if (config()->from_relay) {
size_t ndim = node()->InputAt(1)->Ndim();
std::vector<size_t> axes;
axes.push_back(ndim - 1);
for (size_t i = 0; i < ndim - 1; i++) {
axes.push_back(i);
}
stack_.func_call("relax.op.permute_dims", IdxInput(1))
.call_arg(IdxInput(1))
.call_arg(DocUtils::ToList(axes));
BuilderEmit(IdxInput(1), "permute_" + std::to_string(node()->index));
}
stack_.op_call().op_inputs_arg(false).op_str_arg("mode", "reduction");
}
};

class RelaxResize2dCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxResize2dCodeGen)

@@ -626,6 +654,20 @@ class RelaxSplitCodeGen : public RelaxOpCode {
}
};

class RelaxStackCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxStackCodeGen)

protected:
void CodeGenBuild() final {
stack_.op_call().op_inputs_arg().op_arg<int>("axis");
BuilderEmit(IdxNode(), "cat_" + std::to_string(node()->index));
const auto& out_shape = GetPrims(node()->OutputAt(0));
stack_.func_call("relax.op.reshape", IdxNode())
.call_arg(IdxNode())
.call_arg(DocUtils::ToList(out_shape), "shape");
}
};

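Mirroring the torch frontend, stack on the relax side is emitted as concat plus reshape, roughly (names hypothetical):

    cat_0 = relax.op.concat(inputs, axis=axis)      # emitted under the alias "cat_<node index>"
    out = relax.op.reshape(cat_0, shape=out_shape)  # out_shape recovered from the output prims
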
class RelaxTakeCodeGen : public RelaxOpCode {
RELAX_OP_CODEGEN_METHODS(RelaxTakeCodeGen)

@@ -763,7 +805,11 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<RelaxOpCode>>>
map->emplace("permute_dims", std::make_shared<RelaxPermuteDimsCodeGen>("relax.op.permute_dims"));
map->emplace("repeat", std::make_shared<RelaxRepeatCodeGen>("relax.op.repeat"));
map->emplace("reshape", std::make_shared<RelaxReshapeCodeGen>("relax.op.reshape"));
map->emplace("scatter_elements",
std::make_shared<RelaxScatterElementsCodeGen>("relax.op.scatter_elements"));
map->emplace("scatter_nd", std::make_shared<RelaxScatterNDCodeGen>("relax.op.scatter_nd"));
map->emplace("split", std::make_shared<RelaxSplitCodeGen>("relax.op.split"));
map->emplace("stack", std::make_shared<RelaxStackCodeGen>("relax.op.concat"));
map->emplace("strided_slice",
std::make_shared<RelaxStridedSliceCodeGen>("relax.op.strided_slice"));
map->emplace("take", std::make_shared<RelaxTakeCodeGen>("relax.op.take"));
