
[Numpy Refactor] [Model Deployment] Use TVM to accelerate model inference + deployment #1244

Open
sxjscience opened this issue Jun 10, 2020 · 7 comments
Labels: enhancement, numpyrefactor

Comments


sxjscience commented Jun 10, 2020

Currently, we do have Relay VM support for the NDArray version of MXNet: https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/frontend/mxnet.py

However, the Relay frontend converter is still missing support for numpy arrays, so we should add numpy support first.

I checked the workloads of BERT + ALBERT + ELECTRA + MobileBERT + RoBERTa (backbone only), and these are the ops used:

_npi_transpose
FullyConnected
_npi_pad
_contrib_arange_like
expand_dims
null
_npi_multiply_scalar
_npi_true_divide_scalar
Activation
_npi_add
_npi_concatenate
_npi_multiply
SwapAxis
_np_copy
_npi_tanh
take
softmax
_npi_power_scalar
_npi_less
SequenceMask
LayerNorm
Dropout
erf
Embedding
_npi_add_scalar
_npx_reshape
_split_v2
_npi_where_rscalar
slice
Cast
batch_dot

After some investigation, the following are the ops that still need to be converted:

_npi_transpose
_npi_pad
_contrib_arange_like
null
_npi_multiply_scalar
_npi_true_divide_scalar
_npi_add
_npi_concatenate
_npi_multiply
_np_copy
_npi_tanh
_npi_power_scalar
_npi_less
_npi_add_scalar
_npx_reshape
_split_v2
_npi_where_rscalar

We will revise the relay runtime in TVM accordingly.
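For illustration, adding support for one of these ops typically amounts to writing a small converter function and registering it in `_convert_map` in `tvm/relay/frontend/mxnet.py`. Below is a minimal sketch for `_npi_tanh`, following the style of the existing NDArray converters; the converter function itself is hypothetical and not yet in TVM:

from tvm.relay import op as _op
from tvm.relay.frontend.mxnet import _convert_map

# Hypothetical converter for MXNet's numpy tanh op. Frontend converters
# receive the Relay input expressions and the MXNet attribute dict, and
# return a Relay expression.
def _mx_npi_tanh(inputs, attrs):
    # _npi_tanh takes a single input tensor and no attributes.
    return _op.tanh(inputs[0])

# Registering the converter makes the op visible to relay.frontend.from_mxnet.
_convert_map['_npi_tanh'] = _mx_npi_tanh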

Code for getting the missing ops for relay:

import json
import mxnet as mx
from gluonnlp.models import list_backbone_names, get_backbone

mx.npx.set_np()
batch_size = 1
sequence_length = 32
all_possible_ops = []
for name in list_backbone_names():
    model_cls, cfg, tokenizer, local_params_path, others = get_backbone(model_name=name)
    net = model_cls.from_cfg(cfg)
    net.initialize()
    net.hybridize()
    print('Save the architecture of {} to {}.json'.format(name, name))
    inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
    token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
    valid_length = mx.np.random.randint(1, 10, (batch_size,))
    # RoBERTa/XLM-R backbones do not take token types.
    if 'roberta' in name or 'xlmr' in name:
        out = net(inputs, valid_length)
    else:
        out = net(inputs, token_types, valid_length)
    # One forward pass builds the cached symbolic graph, which we then dump.
    sym = net._cached_graph[1]
    sym.save('{}.json'.format(name), remove_amp_cast=True)
    # Collect the set of ops used by this backbone.
    all_ops = set()
    with open('{}.json'.format(name), 'r') as f:
        sym_info = json.load(f)
        for ele in sym_info['nodes']:
            all_ops.add(ele['op'])
    with open('{}_all_ops.json'.format(name), 'w') as f:
        json.dump(list(all_ops), f)
    all_possible_ops.extend(list(all_ops))

with open('all_possible_ops.json', 'w') as f:
    json.dump(list(set(all_possible_ops)), f)

# Report the ops that the Relay MXNet frontend cannot convert yet.
from tvm.relay.frontend.mxnet import _convert_map

for op in all_possible_ops:
    if op not in _convert_map:
        print(op)
sxjscience added the enhancement and numpyrefactor labels on Jun 10, 2020
sxjscience changed the title from "[Model Deployment] Use TVM to accelerate model inference + deployment" to "[Numpy Refactor] [Model Deployment] Use TVM to accelerate model inference + deployment" on Jun 10, 2020
@sxjscience

@yzhliu The numpy version has been merged, and I've attached the script above for generating the ops missing from the Relay converter.

@carter54

Is there any tutorial on deploying the Gluon GPT-2 model with TVM?


sxjscience commented Jun 23, 2020

@carter54 Thanks for your interest. We will add the support and a tutorial later; it's on the roadmap.

@carter54

Thanks @sxjscience. Looking forward to trying it!


sxjscience commented Nov 2, 2020

@carter54 In case you'd like to try out TVM now, you can use docker (or compile TVM with cublas + blas enabled, as in https://github.com/dmlc/gluon-nlp/blob/master/tools/docker/install/install_tvm_cpu.sh) and refer to our test case here:

# Excerpt from the GluonNLP test suite. The imports below are added for
# context; the helper import paths are assumed from the repo layout.
import tempfile
import mxnet as mx
import numpy as np
import numpy.testing as npt
from gluonnlp.models import get_backbone
from gluonnlp.utils.lazy_imports import try_import_tvm
from gluonnlp.utils.misc import get_ec2_tvm_flags


def test_tvm_integration(model_name, batch_size, seq_length, layout, ctx):
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    tvm_recommended_flags = get_ec2_tvm_flags()
    if ctx.device_type == 'gpu':
        flags = tvm_recommended_flags['g4']
    elif ctx.device_type == 'cpu':
        flags = tvm_recommended_flags['c4']
        if model_name != 'google_albert_base_v2':
            # Skip all other tests
            return
    else:
        raise NotImplementedError
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name, root=root)
        cfg.defrost()
        cfg.MODEL.layout = layout
        cfg.freeze()
        model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path)
        model.hybridize()
        if layout == 'NT':
            token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length),
                                             dtype=np.int32)
            token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
            valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,),
                                                dtype=np.int32)
        else:
            token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (seq_length, batch_size),
                                             dtype=np.int32)
            token_types = mx.np.random.randint(0, 2, (seq_length, batch_size), dtype=np.int32)
            valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,),
                                                dtype=np.int32)
        if 'bart' in model_name:
            mx_out = model(token_ids, valid_length, token_ids, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': valid_length.shape,
                'data2': token_ids.shape,
                'data3': valid_length.shape,
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': valid_length.dtype.name,
                'data2': token_ids.dtype.name,
                'data3': valid_length.dtype.name,
            }
        elif 'roberta' in model_name or 'xlmr' in model_name:
            mx_out = model(token_ids, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': valid_length.shape,
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': valid_length.dtype.name,
            }
        else:
            mx_out = model(token_ids, token_types, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': token_types.shape,
                'data2': valid_length.shape
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': token_types.dtype.name,
                'data2': valid_length.dtype.name
            }
        sym = model._cached_graph[1]
        params = {}
        for k, v in model.collect_params().items():
            params[v._var_name] = tvm.nd.array(v.data().asnumpy())
        mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict,
                                                arg_params=params)
        target = flags['target']
        use_gpu = flags['use_gpu']
        opt_level = flags['opt_level']
        required_pass = flags['required_pass']
        with tvm.transform.PassContext(opt_level=opt_level, required_pass=required_pass):
            lib = relay.build(mod, target, params=params)
        if use_gpu:
            ctx = tvm.gpu()
        else:
            ctx = tvm.cpu()
        rt = graph_runtime.GraphModule(lib["default"](ctx))
        if 'bart' in model_name:
            rt.set_input(data0=token_ids, data1=valid_length, data2=token_ids, data3=valid_length)
        elif 'roberta' in model_name:
            rt.set_input(data0=token_ids, data1=valid_length)
        else:
            rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
        rt.run()
        for i in range(rt.get_num_outputs()):
            out = rt.get_output(i)
            if rt.get_num_outputs() == 1:
                mx_out_gt = mx_out.asnumpy()
            else:
                mx_out_gt = mx_out[i].asnumpy()
            npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=1e-3, atol=1e-1)
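For deployment, the compiled module can then be exported as a shared library and reloaded in a separate process through TVM's standard runtime API. A minimal sketch continuing from the test above (the file name is arbitrary, and the inputs are reused from the test for illustration):

# Export the compiled module produced by relay.build above.
lib.export_library('backbone_compiled.so')

# Later, e.g. in the serving process: reload the module and run inference.
loaded_lib = tvm.runtime.load_module('backbone_compiled.so')
dev = tvm.cpu()  # or tvm.gpu() if the module was compiled for GPU
rt = graph_runtime.GraphModule(loaded_lib['default'](dev))
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
out = rt.get_output(0)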

We are currently writing a tutorial on how to convert GluonNLP backbones to TVM, so you can also wait for the official version.

@carter54

@sxjscience Thanks a lot! I will give it a try.
