Change TOPI ops to use C++ implementation where applicable (apache#357)
* Updated TVM version. Implemented a fix for the nnvm_compiler crash on exit on Windows. Changed TOPI ops from using Python to using C++ where applicable.

* Fix lint

* Fix lint

* Fix macro

* Fix reshape

* Update TVM to fix test failures
alex-weaver authored and tqchen committed Feb 6, 2018
1 parent 6004907 commit 96b7934
Showing 17 changed files with 422 additions and 170 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -22,6 +22,7 @@ include_directories(BEFORE "include")
include_directories("tvm/include")
include_directories("tvm/dlpack/include")
include_directories("tvm/HalideIR/src")
include_directories("tvm/topi/include")

set(NNVM_LINKER_LIBS "")
set(NNVM_COMPILER_LINKER_LIBS "")
2 changes: 1 addition & 1 deletion Makefile
@@ -11,7 +11,7 @@ include $(config)

export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
-CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src
+CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src -Itvm/topi/include

ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include
33 changes: 33 additions & 0 deletions include/nnvm/compiler/util.h
@@ -0,0 +1,33 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file util.h
+ * \brief Utility functions for nnvm compiler
+ */
+#ifndef NNVM_COMPILER_UTIL_H_
+#define NNVM_COMPILER_UTIL_H_
+
+#include <tvm/expr.h>
+#include <nnvm/tuple.h>
+
+namespace nnvm {
+namespace compiler {
+
+/*
+ * \brief Helper function to convert TShape to TVM array. Useful for
+ * passing data from NNVM param structures to TOPI ops.
+ *
+ * \param shape The shape to convert
+ *
+ * \return An Array of Expr, where each element is a constant int32
+ */
+inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
+  tvm::Array<tvm::Expr> result;
+  for (auto i : shape) {
+    result.push_back(tvm::make_const(tvm::Int(32), i));
+  }
+  return result;
+}
+
+}  // namespace compiler
+}  // namespace nnvm
+#endif  // NNVM_COMPILER_UTIL_H_
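
For orientation, here is a minimal usage sketch of the new ShapeToArray helper: it converts an NNVM shape parameter into the array of constant int32 exprs that a TOPI op expects. The wrapper function LowerReshape and its parameter names are hypothetical illustrations, not code from this commit; topi::reshape is assumed to accept an Array<Expr> target shape.

// Hypothetical usage sketch; only ShapeToArray itself comes from this commit.
#include <nnvm/compiler/util.h>
#include <topi/transform.h>

tvm::Tensor LowerReshape(const tvm::Tensor& data, const nnvm::TShape& target_shape) {
  // Convert the NNVM param shape into constant int32 exprs for TOPI.
  tvm::Array<tvm::Expr> newshape = nnvm::compiler::ShapeToArray(target_shape);
  return topi::reshape(data, newshape);
}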
2 changes: 1 addition & 1 deletion include/nnvm/symbolic.h
@@ -28,7 +28,7 @@ namespace nnvm {
 * symbol is the final operation of a graph and thus including all the information
 * required (the graph) to evaluate its output value.
 */
-class Symbol {
+class NNVM_DLL Symbol {
 public:
  /*! \brief option passed to ListAttr */
  enum ListAttrOption {
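
The NNVM_DLL annotation exports the Symbol class across the shared-library boundary, which is what the Windows crash-on-exit fix mentioned in the commit message depends on. The macro's definition is not part of this diff; a conventional dmlc-style layout, assumed here purely for illustration, looks like:

// Assumed, conventional export-macro layout (the real definition is not shown in this diff).
#ifdef _WIN32
#ifdef NNVM_EXPORTS
#define NNVM_DLL __declspec(dllexport)   // building the nnvm DLL itself
#else
#define NNVM_DLL __declspec(dllimport)   // consuming the DLL from another binary
#endif
#else
#define NNVM_DLL __attribute__((visibility("default")))
#endif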
85 changes: 0 additions & 85 deletions python/nnvm/top/nn.py
@@ -10,59 +10,26 @@
from .registry import OpPattern

# relu
@reg.register_compute("relu")
def compute_relu(attrs, inputs, _):
"""Compute definition of relu"""
return topi.nn.relu(inputs[0])

reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEMWISE)


# leaky_relu
@reg.register_compute("leaky_relu")
def compute_leaky_relu(attrs, inputs, _):
"""Compute definition of relu"""
return topi.nn.leaky_relu(inputs[0], attrs.get_float("alpha"))

reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)


# flatten
@reg.register_compute("flatten")
def compute_flatten(attrs, inputs, _):
"""Compute definition of flatten"""
return topi.nn.flatten(inputs[0])

reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.INJECTIVE)


# pad
@reg.register_compute("pad")
def compute_pad(attrs, inputs, _):
"""Compute definition of pad"""
pad_width = attrs.get_int_pair_tuple('pad_width')
assert len(pad_width) == len(inputs[0].shape) and \
len(pad_width[0]) == 2, "illegal pad_width"
pad_before = [x[0] for x in pad_width]
pad_after = [x[1] for x in pad_width]
pad_value = attrs.get_int('pad_value')
return topi.nn.pad(inputs[0], pad_before, pad_after, pad_value)

reg.register_schedule("pad", _fschedule_broadcast)
reg.register_pattern("pad", OpPattern.INJECTIVE)


# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs, _):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.softmax(inputs[0])

@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
"""Schedule definition of softmax"""
@@ -73,13 +40,6 @@ def schedule_softmax(_, outs, target):


# log softmax
@reg.register_compute("log_softmax")
def compute_log_softmax(attrs, inputs, _):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.log_softmax(inputs[0])

@reg.register_schedule("log_softmax")
def schedule_log_softmax(_, outs, target):
"""Schedule definition of softmax"""
@@ -91,13 +51,6 @@ def schedule_log_softmax(_, outs, target):


# dense
@reg.register_compute("dense")
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1])

@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
"""Schedule definition of dense"""
@@ -175,18 +128,6 @@ def schedule_conv2d_transpose(attrs, outs, target):


# max_pool2d
@reg.register_compute("max_pool2d")
def compute_max_pool2d(attrs, inputs, _):
"""Compute definition of max_pool2d"""
pool_size = attrs.get_int_tuple("pool_size")
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
return topi.nn.pool(inputs[0], pool_size, strides, padding,
pool_type='max', ceil_mode=ceil_mode)

@reg.register_schedule("max_pool2d")
def schedule_max_pool2d(_, outs, target):
"""Schedule definition of max_pool2d"""
@@ -197,18 +138,6 @@ def schedule_max_pool2d(_, outs, target):


# avg_pool2d
@reg.register_compute("avg_pool2d")
def compute_avg_pool2d(attrs, inputs, _):
"""Compute definition of avg_pool2d"""
pool_size = attrs.get_int_tuple("pool_size")
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
return topi.nn.pool(inputs[0], pool_size, strides, padding,
pool_type='avg', ceil_mode=ceil_mode)

@reg.register_schedule("avg_pool2d")
def schedule_avg_pool2d(_, outs, target):
"""Schedule definition of avg_pool2d"""
@@ -219,13 +148,6 @@ def schedule_avg_pool2d(_, outs, target):


# global_max_pool2d
@reg.register_compute("global_max_pool2d")
def compute_global_max_pool2d(attrs, inputs, _):
"""Compute definition of global_max_pool2d"""
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
return topi.nn.global_pool(inputs[0], pool_type='max')

@reg.register_schedule("global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
"""Schedule definition of global_max_pool2d"""
@@ -236,13 +158,6 @@ def schedule_global_max_pool2d(_, outs, target):


# global_avg_pool2d
@reg.register_compute("global_avg_pool2d")
def compute_global_avg_pool2d(attrs, inputs, _):
"""Compute definition of global_avg_pool2d"""
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
return topi.nn.global_pool(inputs[0], pool_type='avg')

@reg.register_schedule("global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
"""Schedule definition of global_avg_pool2d"""
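
Per the commit title, the compute definitions deleted above move to the C++ side; the files that add them are among those not shown in this view. As a hedged sketch of the general pattern rather than the literal code of this commit, a C++ FTVMCompute registration that delegates to TOPI looks roughly like this:

// Sketch only: mirrors the NNVM FTVMCompute registration convention,
// replacing a Python compute such as compute_relu above.
#include <nnvm/op.h>
#include <nnvm/compiler/op_attr_types.h>
#include <topi/nn.h>

namespace nnvm {
namespace top {

NNVM_REGISTER_OP(relu)
.set_attr<compiler::FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const tvm::Array<tvm::Tensor>& inputs,
                    const tvm::Array<tvm::Tensor>& out_info) {
    // Delegate the elementwise computation to the C++ TOPI implementation.
    return tvm::Array<tvm::Tensor>{ topi::relu(inputs[0], 0.0f) };
  });

}  // namespace top
}  // namespace nnvm

Keeping compute in C++ means graph compilation no longer calls back into Python for these ops, which is one motivation the commit title points at.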
3 changes: 0 additions & 3 deletions python/nnvm/top/reduction.py
@@ -27,16 +27,13 @@ def _compute(attrs, inputs, out_info):
    return _compute

# sum
reg.register_compute("sum", _compute_reduce(topi.sum))
reg.register_pattern("sum", OpPattern.COMM_REDUCE)
reg.register_schedule("sum", _fschedule_reduce)

# max
reg.register_compute("max", _compute_reduce(topi.max))
reg.register_pattern("max", OpPattern.COMM_REDUCE)
reg.register_schedule("max", _fschedule_reduce)

# min
reg.register_compute("min", _compute_reduce(topi.min))
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)
35 changes: 0 additions & 35 deletions python/nnvm/top/tensor.py
@@ -43,132 +43,97 @@ def _compute(attrs, x, _):
_fschedule_elemwise = _fschedule_injective

# copy
reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)

# exp
reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)

# sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEMWISE)
reg.register_schedule("sqrt", _fschedule_broadcast)

# log
reg.register_compute("log", _compute_unary(topi.log))
reg.register_pattern("log", OpPattern.ELEMWISE)
reg.register_schedule("log", _fschedule_broadcast)

# tanh
reg.register_compute("tanh", _compute_unary(topi.tanh))
reg.register_pattern("tanh", OpPattern.ELEMWISE)
reg.register_schedule("tanh", _fschedule_broadcast)

# negative
reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEMWISE)
reg.register_schedule("negative", _fschedule_broadcast)

# sigmoid
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
reg.register_schedule("sigmoid", _fschedule_broadcast)

# add_scalar
reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast)

# sub_scalar
reg.register_compute("__sub_scalar__",
_compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)

# rsub_scalar
reg.register_compute("__rsub_scalar__",
_compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)

# mul_scalar
reg.register_compute("__mul_scalar__",
_compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)

# div_scalar
reg.register_compute("__div_scalar__",
_compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast)

# rdiv_scalar
reg.register_compute("__rdiv_scalar__",
_compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)

# pow_scalar
reg.register_compute("__pow_scalar__",
_compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast)

# rpow_scalar
reg.register_compute("__rpow_scalar__",
_compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)

# elemwise_add
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)

# elemwise_sub
reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
reg.register_schedule("elemwise_sub", _fschedule_broadcast)

# elemwise_mul
reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
reg.register_schedule("elemwise_mul", _fschedule_broadcast)

# elemwise_div
reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
reg.register_schedule("elemwise_div", _fschedule_broadcast)

# broadcast_add
reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
reg.register_schedule("broadcast_add", _fschedule_broadcast)

# broadcast_sub
reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
reg.register_schedule("broadcast_sub", _fschedule_broadcast)

# broadcast_mul
reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
reg.register_schedule("broadcast_mul", _fschedule_broadcast)

# broadcast_div
reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
reg.register_schedule("broadcast_div", _fschedule_broadcast)

# broadcast_to
@reg.register_compute("broadcast_to")
def compute_broadcast_to(attrs, inputs, out_info):
"""Compute definition of softmax"""
return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast)
(Diff truncated: 10 of the 17 changed files are not shown in this view.)
