This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[retiarii] graph gen and code gen improvement #3330

Closed
wants to merge 19 commits
37 changes: 33 additions & 4 deletions nni/retiarii/codegen/pytorch.py
@@ -33,8 +33,25 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:


def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
"""
Format the inputs of a given node.

Parameters
----------
node : Node
a graph node whose inputs are collected and formatted

Returns
-------
list
the list of input names
list
the list of input values; if an input is of a simple type, its value is recorded,
otherwise the entry is None
"""
edges = _sorted_incoming_edges(node)
inputs = []
inputs_value = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int)
@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]:
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
inputs_value.append(None)
else:
if edge.head_slot is None:
# when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name))
if edge.head.operation.type == 'prim::Constant' and \
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
else:
inputs_value.append(None)
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
inputs_value.append(None)
return inputs, inputs_value


def _remove_prefix(names, graph_name):
@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_codes = []
for node in nodes:
if node.operation:
if node.operation.type == 'shared':
continue
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs, inputs_value = _format_inputs(node)
inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name)
submodule_name = node_name
if node.operation.type == 'shared':
submodule_name = _remove_prefix(node.operation.parameters['reference'], graph_name)
edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))

output_names, _ = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
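To make the new `_format_inputs` contract concrete, here is a standalone, runnable sketch; plain dicts stand in for the real `Node`/`Edge` classes, and names like `pool1` are illustrative. Only inputs produced by a `prim::Constant` node carry a value, every other position holds `None`, and `graph_to_pytorch_model` threads the pair through to `to_forward_code` (for `shared` nodes, it calls the referenced submodule while keeping the node's own output name).

```python
from typing import Any, Dict, List, Optional, Tuple

def format_inputs_sketch(edges: List[Dict[str, Any]]) -> Tuple[List[str], List[Optional[Any]]]:
    """Pair each input name with its constant value, or None if it is not a constant."""
    inputs: List[str] = []
    inputs_value: List[Optional[Any]] = []
    for edge in edges:
        inputs.append(edge['name'])
        params = edge.get('params', {})
        if edge.get('op_type') == 'prim::Constant' and 'value' in params:
            inputs_value.append(params['value'])
        else:
            inputs_value.append(None)
    return inputs, inputs_value

names, values = format_inputs_sketch([
    {'name': 'pool1'},  # an ordinary tensor-producing node
    {'name': 'const2', 'op_type': 'prim::Constant', 'params': {'value': 1}},
])
assert names == ['pool1', 'const2']
assert values == [None, 1]
```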
326 changes: 215 additions & 111 deletions nni/retiarii/converter/graph_gen.py

Large diffs are not rendered by default.

43 changes: 42 additions & 1 deletion nni/retiarii/converter/op_types.py
@@ -10,7 +10,9 @@ class OpTypeName(str, Enum):
Attr = 'Attr'
Constant = 'Constant'
ListConstruct = 'ListConstruct'
ListUnpack = 'ListUnpack'
TupleConstruct = 'TupleConstruct'
TupleUnpack = 'TupleUnpack'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
@@ -36,7 +38,46 @@ class OpTypeName(str, Enum):
'aten::empty': 'Empty',
'aten::zeros': 'Zeros',
'aten::chunk': 'Chunk',
'aten::add_': 'Add_', # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
'aten::flatten': 'Flatten',
'aten::sigmoid': 'Sigmoid',
'aten::detach': 'Detach',
'aten::le': 'Le',
'aten::__not__': 'not',
'aten::transpose': 'Transpose',
'aten::contiguous': 'Contiguous',
'aten::new_full': 'NewFull',
'aten::new_empty': 'NewEmpty',
'aten::new_zeros': 'NewZeros',
'aten::tensor': 'Tensor',
'aten::abs': 'Abs',
'aten::abs_': 'Abs_',
'aten::acos': 'Acos',
'aten::acos_': 'Acos_',
'aten::asin': 'Asin',
'aten::atan': 'Atan',
'aten::atan2': 'Atan2',
'aten::addbmm': 'Addbmm',
'aten::baddbmm': 'Baddbmm',
'aten::addcdiv': 'Addcdiv',
'aten::addcmul': 'Addcmul',
'aten::addmm': 'Addmm',
'aten::addmv': 'Addmv',
'aten::addr': 'Addr',
'aten::bmm': 'Bmm',
'aten::allclose': 'Allclose',
'aten::angle': 'Angle',
'aten::argmax': 'Argmax',
'aten::argmin': 'Argmin',
'aten::argsort': 'Argsort',
'aten::bernoulli': 'Bernoulli',
'aten::bincount': 'Bincount',
'aten::bitwise_not': 'BitwiseNot',
'aten::bitwise_and': 'BitwiseAnd',
'aten::bitwise_or': 'BitwiseOr',
'aten::bitwise_xor': 'BitwiseXor',
'prim::is_cuda': 'IsCuda'
}

BasicOpsTF = {}
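For context, the `aten::...` keys above are TorchScript node kinds. A quick, runnable way to see which kinds a model produces, and therefore which entries this table must cover, is to script a function and inspect its graph:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.flatten(x, 1).sigmoid()

# prints kinds such as 'aten::flatten' and 'aten::sigmoid', plus 'prim::Constant'
# nodes holding the literal arguments (e.g. the 1 passed to flatten)
print([node.kind() for node in f.graph.nodes()])
```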
2 changes: 1 addition & 1 deletion nni/retiarii/nn/pytorch/mutator.py
@@ -45,7 +45,7 @@ def mutate(self, model):
chosen = self.choice(self.candidates)
for node in self.nodes:
target = model.get_node_by_name(node.name)
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})


def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
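The added `type` field is what lets the code generator emit a correct literal later. Below is a standalone sketch of the round trip; the helper is hypothetical and mirrors the `prim::Constant` branch of `to_forward_code` in `operation.py`, and `_val_1` is an illustrative output name.

```python
def constant_forward_line(output: str, params: dict) -> str:
    # mirrors the prim::Constant code generation: dispatch on the recorded type
    if params['type'] == 'None':
        return f'{output} = None'
    if params['type'] in ('int', 'float', 'bool'):
        return f'{output} = {params["value"]}'
    if params['type'] == 'Device':
        return f'{output} = torch.device("{params["value"]}")'
    raise RuntimeError(f'unsupported type of prim::Constant: {params["type"]}')

chosen = 0.5  # e.g. a value picked from a ValueChoice's candidates
params = {'type': type(chosen).__name__, 'value': chosen}  # what mutate() records
assert constant_forward_line('_val_1', params) == '_val_1 = 0.5'
```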
177 changes: 169 additions & 8 deletions nni/retiarii/operation.py
@@ -4,6 +4,30 @@

__all__ = ['Operation', 'Cell']

mem_format = [
'torch.contiguous_format', # 0
'torch.preserve_format', # 1
'torch.channels_last', # 2
]

# this snippet is copied from torch/onnx/symbolic_helper.py;
# the original definition is in c10/core/ScalarType.h.
# each index in this list is a ScalarType value, and the entry is the corresponding PyTorch dtype string
scalar_type_to_pytorch_type = [
'torch.uint8', # 0
'torch.int8', # 1
'torch.short', # 2
'torch.int', # 3
'torch.int64', # 4
'torch.half', # 5
'torch.float', # 6
'torch.double', # 7
'torch.complex32', # 8
'torch.complex64', # 9
'torch.complex128', # 10
'torch.bool', # 11
]


def _convert_name(name: str) -> str:
"""
@@ -106,25 +130,56 @@ def to_init_code(self, field: str) -> str:
return f'self.{field} = {self._to_class_name()}({kw_params})'
return None

def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
"""
Parameters
----------
field : str
the name of member submodule
output : str
the output name (lvalue) of this line of code
inputs : List[str]
variables used in this line of code
inputs_value : List[Any]
some variables are actually constants; their real values are recorded in ``inputs_value``.
if an input is not a constant, None is put at the corresponding index
"""
from .converter.op_types import OpTypeName
if self._to_class_name() is not None:
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type == 'shared':
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type.startswith('Function.'):
func_name = self.type[len('Function.'):]
return f'{output} = F.{func_name}({", ".join(inputs)})'
elif self.type == 'prim::Constant':
if self.parameters:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] == 'None':
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool'):
return f'{output} = {self.parameters["value"]}'
elif self.parameters['type'] == 'Device':
value = self.parameters['value']
return f'{output} = torch.device("{value}")'
else:
raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'prim::TupleConstruct':
return f'{output} = ({", ".join(inputs)})'
elif self.type == 'prim::TupleUnpack':
# there is only a single output here, because the following code uses an index to access the unpacked values
assert len(inputs) == 1
return f'{output} = {inputs[0]}'
elif self.type == 'prim::GetAttr':
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
elif self.type == 'prim::is_cuda':
return f'{output} = {inputs[0]}.is_cuda'
elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__':
@@ -137,7 +192,19 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add':
# TODO: verify the correctness
#return f'{output} = ' + ' + '.join(inputs)
if len(inputs) == 2:
return f'{output} = {inputs[0]}.add({inputs[1]})'
else:
assert len(inputs) == 3
return f'{output} = {inputs[0]}.add({inputs[1]}, alpha={inputs[2]})'
elif self.type == 'aten::add_':
if len(inputs) == 2:
return f'{output} = {inputs[0]}.add_({inputs[1]})'
else:
assert len(inputs) == 3
return f'{output} = {inputs[0]}.add_({inputs[1]}, alpha={inputs[2]})'
elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
@@ -147,8 +214,10 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif self.type == 'aten::size':
if len(inputs) == 2:
return f'{output} = {inputs[0]}.size({inputs[1]})'
else:
return f'{output} = {inputs[0]}.size()'
elif self.type == 'aten::view':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.view({inputs[1]})'
@@ -159,6 +228,98 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
raise RuntimeError('not supposed to have aten::slice operation')
elif self.type == 'aten::Bool':
return f'{output} = bool({inputs[0]})'
elif self.type == 'aten::flatten':
return f'{output} = torch.flatten({inputs[0]}, {inputs[1]}, {inputs[2]})'
elif self.type == 'aten::sigmoid':
assert len(inputs) == 1
return f'{output} = torch.sigmoid({inputs[0]})'
elif self.type == 'aten::__not__':
return f'{output} = not {inputs[0]}'
elif self.type == 'aten::transpose':
return f'{output} = {inputs[0]}.transpose({inputs[1]}, {inputs[2]})'
elif self.type == 'aten::contiguous':
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
elif self.type == 'aten::detach':
return f'{output} = {inputs[0]}.detach()'
elif self.type == 'aten::new_full':
device_str = f', device=torch.device({inputs[5]})' if inputs_value[5] is not None else ''
dtype_str = f', dtype={scalar_type_to_pytorch_type[inputs_value[3]]}' if inputs_value[3] is not None else ''
return f'{output} = {inputs[0]}.new_full({inputs[1]}, {inputs[2]}{dtype_str}{device_str})'
elif self.type == 'aten::new_empty':
device_str = f', device=torch.device({inputs[4]})' if inputs_value[4] is not None else ''
dtype_str = f', dtype={scalar_type_to_pytorch_type[inputs_value[2]]}' if inputs_value[2] is not None else ''
return f'{output} = {inputs[0]}.new_empty({inputs[1]}{dtype_str}{device_str})'
elif self.type == 'aten::new_zeros':
# in pytorch: new_zeros(size, dtype=None, device=None, requires_grad=False) → Tensor
# in aten: - func: new_zeros(Tensor self, int[] size, *, ScalarType? dtype=None,
# Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
# TODO: check requires_grad when it is true!!!
device_str = f', device=torch.device({inputs[4]})' if inputs_value[4] is not None else ''
dtype_str = f', dtype={scalar_type_to_pytorch_type[inputs_value[2]]}' if inputs_value[2] is not None else ''
return f'{output} = {inputs[0]}.new_zeros({inputs[1]}{dtype_str}{device_str})'
elif self.type == 'aten::tensor':
device_str = f', device=torch.device({inputs[2]})' if inputs_value[2] is not None else ''
dtype_str = f', dtype={scalar_type_to_pytorch_type[inputs_value[1]]}' if inputs_value[1] is not None else ''
req_grad_str = f', requires_grad={inputs[3]}' if inputs_value[3] else ''
return f'{output} = torch.tensor({inputs[0]}{dtype_str}{device_str}{req_grad_str})'
elif self.type == 'aten::abs':
return f'{output} = {inputs[0]}.abs()'
elif self.type == 'aten::abs_':
return f'{output} = {inputs[0]}.abs_()'
elif self.type == 'aten::acos':
return f'{output} = {inputs[0]}.acos()'
elif self.type == 'aten::acos_':
return f'{output} = {inputs[0]}.acos_()'
elif self.type == 'aten::asin':
return f'{output} = {inputs[0]}.asin()'
elif self.type == 'aten::atan':
return f'{output} = {inputs[0]}.atan()'
elif self.type == 'aten::atan2':
return f'{output} = {inputs[0]}.atan2({inputs[1]})'
elif self.type == 'aten::addbmm':
return f'{output} = {inputs[0]}.addbmm({inputs[1]}, {inputs[2]}, beta={inputs[3]}, alpha={inputs[4]})'
elif self.type == 'aten::baddbmm':
return f'{output} = {inputs[0]}.baddbmm({inputs[1]}, {inputs[2]}, beta={inputs[3]}, alpha={inputs[4]})'
elif self.type == 'aten::addcdiv':
return f'{output} = {inputs[0]}.addcdiv({inputs[1]}, {inputs[2]}, value={inputs[3]})'
elif self.type == 'aten::addcmul':
return f'{output} = {inputs[0]}.addcmul({inputs[1]}, {inputs[2]}, value={inputs[3]})'
elif self.type == 'aten::addmm':
return f'{output} = {inputs[0]}.addmm({inputs[1]}, {inputs[2]}, beta={inputs[3]}, alpha={inputs[4]})'
elif self.type == 'aten::addmv':
return f'{output} = {inputs[0]}.addmv({inputs[1]}, {inputs[2]}, beta={inputs[3]}, alpha={inputs[4]})'
elif self.type == 'aten::bmm':
return f'{output} = {inputs[0]}.bmm({inputs[1]})'
elif self.type == 'aten::addr':
return f'{output} = {inputs[0]}.addr({inputs[1]}, {inputs[2]}, beta={inputs[3]}, alpha={inputs[4]})'
elif self.type == 'aten::allclose':
return f'{output} = {inputs[0]}.allclose({inputs[1]}, rtol={inputs[2]}, atol={inputs[3]}, equal_nan={inputs[4]})'
elif self.type == 'aten::angle':
return f'{output} = {inputs[0]}.angle()'
elif self.type == 'aten::argmax':
return f'{output} = {inputs[0]}.argmax(dim={inputs[1]}, keepdim={inputs[2]})'
elif self.type == 'aten::argmin':
return f'{output} = {inputs[0]}.argmin(dim={inputs[1]}, keepdim={inputs[2]})'
elif self.type == 'aten::argsort':
return f'{output} = {inputs[0]}.argsort(dim={inputs[1]}, descending={inputs[2]})'
elif self.type == 'aten::bernoulli':
assert inputs_value[1] is None
return f'{output} = {inputs[0]}.bernoulli()'
elif self.type == 'aten::bincount':
return f'{output} = {inputs[0]}.bincount(weights={inputs[1]}, minlength={inputs[2]})'
elif self.type == 'aten::bitwise_not':
return f'{output} = {inputs[0]}.bitwise_not()'
elif self.type == 'aten::bitwise_and':
return f'{output} = {inputs[0]}.bitwise_and({inputs[1]})'
elif self.type == 'aten::bitwise_or':
return f'{output} = {inputs[0]}.bitwise_or({inputs[1]})'
elif self.type == 'aten::bitwise_xor':
return f'{output} = {inputs[0]}.bitwise_xor({inputs[1]})'
elif self.type == 'noop_identity':
# this operator type is added by us
return f'{output} = {", ".join(inputs)}'
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')

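As a sanity check on a few of the translations added above, the following plain-PyTorch snippet (no NNI imports; tensor names are illustrative) confirms the `alpha` semantics of `aten::add`, the ScalarType-index-to-dtype mapping, and the shapes produced by several of the one-line translations:

```python
import torch

x = torch.ones(2)
y = torch.full((2,), 3.0)
# aten::add with three inputs means "x + alpha * y", hence .add(y, alpha=...)
assert torch.equal(x.add(y, alpha=2), x + 2 * y)  # tensor([7., 7.])

# dtype arguments arrive as integer ScalarType indices; index 6 maps to
# torch.float in scalar_type_to_pytorch_type, so generated code reads like:
z = x.new_zeros((2, 3), dtype=torch.float)
assert z.dtype == torch.float

# a few of the newly covered translations, as they would appear in a
# generated forward()
t = torch.randn(2, 3, 4)
out1 = torch.flatten(t, 1, -1)                                  # aten::flatten
out2 = torch.sigmoid(out1)                                      # aten::sigmoid
out3 = out2.transpose(0, 1)                                     # aten::transpose
out4 = out3.contiguous(memory_format=torch.contiguous_format)   # aten::contiguous
out5 = out4.size()                                              # aten::size with no dim argument
assert out5 == torch.Size([12, 2])
```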
8 changes: 6 additions & 2 deletions nni/retiarii/utils.py
@@ -112,14 +118,18 @@ def blackbox_module(cls):
assert (inspect.getmodule(frm[0]) is not None), ('unable to locate the definition of the given black box module, '
'please define it explicitly in a .py file.')
module_name = inspect.getmodule(frm[0]).__name__

if module_name == '__main__':
main_file_path = Path(inspect.getsourcefile(frm[0]))
if main_file_path.parents[0] != Path('.'):
raise RuntimeError(f'you are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem

# NOTE: this is hacky. TorchScript retrieves LSTM's source code during scripting;
# to make sure LSTM's source code can be found, we should assign the original LSTM's
# __module__ to the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
if f'{cls.__module__}.{cls.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls.__module__
return _blackbox_cls(cls, module_name, 'args')
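A hedged usage sketch of `blackbox_module` (the `Swish` module is hypothetical): decorating a module defined in a regular `.py` file lets retiarii record it as a single opaque node identified by its module path. Per the `__main__` branch above, if the class is defined in the launch script itself, the experiment must be launched from the directory containing that script.

```python
# my_ops.py (hypothetical file)
import torch
import torch.nn as nn
from nni.retiarii.utils import blackbox_module

@blackbox_module
class Swish(nn.Module):
    """Recorded by the converter as a single node, not traced internally."""
    def forward(self, x):
        return x * torch.sigmoid(x)
```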

