Skip to content

Commit

Permalink
[TIR] Remove PrimFuncNode::preflattened_buffer_map (#10940)
Browse files Browse the repository at this point in the history
`PrimFuncNode::preflattened_buffer_map` was introduced in
#9727, in order to maintain a record
of the pre-flattened buffer shape until it can be used in
`MakePackedAPI`.  This commit instead maintains the pre-flattened
shapes in `PrimFuncNode::buffer_map`, while the body of the function
uses a flattened buffer alias, as described in
[RFC#70](apache/tvm-rfcs#70)
  • Loading branch information
Lunderberg authored Nov 16, 2022
1 parent 0d9b185 commit 78b5322
Show file tree
Hide file tree
Showing 50 changed files with 462 additions and 780 deletions.
3 changes: 0 additions & 3 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class PrimFuncFrameNode : public TIRFrameNode {
Optional<Type> ret_type;
/*! \brief Maps some parameters to specific Buffer data structures. */
Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
/*! \brief The buffer map prior to flattening. */
Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
/*! \brief Additional attributes storing the meta-data */
Optional<Map<String, ObjectRef>> attrs;
/*! \brief The variable map bound to thread env. */
Expand All @@ -90,7 +88,6 @@ class PrimFuncFrameNode : public TIRFrameNode {
v->Visit("args", &args);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
v->Visit("attrs", &attrs);
v->Visit("env_threads", &env_threads);
v->Visit("root_alloc_buffers", &root_alloc_buffers);
Expand Down
20 changes: 0 additions & 20 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,6 @@ Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = Data
int align = -1, int offset_factor = 0, String buffer_type = "default",
Array<IntImm> axis_separators = {});

/*!
* \brief The pre-flattened buffer statement.
* \param postflattened_buffer The original buffer to be flattened.
 * \param shape The shape of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
*/
void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
String storage_scope = "global", int align = -1, int offset_factor = 0,
String buffer_type = "default", Array<IntImm> axis_separators = {});

/*!
* \brief The block declaration statement.
* \param name The name of the block.
Expand Down
43 changes: 11 additions & 32 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,33 +88,22 @@ class PrimFuncNode : public BaseFuncNode {
* While we could have express parameter unpacking and constraint using
* normal statements, making buffer_map as first class citizen of PrimFunc
* will make program analysis much easier.
*/
Map<tir::Var, Buffer> buffer_map;

/*! \brief The buffer map prior to flattening.
*
 * This contains the buffers as they exist prior to flattening, and
* is used for validating an input tensor passed into the packed
* API. Any buffer that is present in `buffer_map` but not present
* in `preflattened_buffer_map` is assumed to be the same before
* and after flattening (e.g. a 1-d tensor that is backed by 1-d
* flat memory).
*
* TODO(Lunderberg): Remove preflattened_buffer_map, and instead
* declare each flattened buffer as aliasing the original tensor
* shape. This should include improving the StmtExprMutator to
* provide easier interactions with Buffer objects, so that the
* bookkeeping of relationships between buffers doesn't need to be
* repeated across several transforms.
* Prior to buffer flattening, which is performed either in
* StorageFlatten for TE-based schedules or in FlattenBuffer for
* TIR-based schedules, these buffer objects are used directly in
* the body of the function. After buffer flattening, these buffer
* objects remain unflattened for use in argument validation, but
* all usage in the body of the function is done through a
* flattened alias of the buffer.
*/
Map<tir::Var, Buffer> preflattened_buffer_map;
Map<tir::Var, Buffer> buffer_map;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
Expand All @@ -123,15 +112,13 @@ class PrimFuncNode : public BaseFuncNode {
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
// visit params and buffer_map first as they contains defs.
return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
equal(preflattened_buffer_map, other->preflattened_buffer_map) &&
equal(ret_type, other->ret_type) && equal(body, other->body) &&
equal(attrs, other->attrs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(params);
hash_reduce(buffer_map);
hash_reduce(preflattened_buffer_map);
hash_reduce(ret_type);
hash_reduce(body);
hash_reduce(attrs);
Expand Down Expand Up @@ -169,21 +156,13 @@ class PrimFunc : public BaseFunc {
* PrimFunc. (e.g. a buffer of shape ``[1024]`` originally
* generated as a tensor of shape ``[32, 32]``)
*
* \param preflattened_buffer_map The buffer map for
* parameter buffer unpacking. This contains buffer
* objects as they are expected to be passed in by the
* callee. (e.g. a buffer of shape ``[32, 32]`` originally
* generated as a tensor of shape ``[32, 32]``)
*
* \param attrs Additional function attributes.
*
* \param span The location of this object in the source code.
*/
TVM_DLL PrimFunc(
Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = Optional<Map<tir::Var, Buffer>>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
Expand Down
77 changes: 53 additions & 24 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ def _ftransform(f, mod, ctx):
new_body,
f.ret_type,
new_buffer_map,
f.preflattened_buffer_map,
f.attrs,
f.span,
)
Expand Down Expand Up @@ -327,7 +326,7 @@ def EncodeConstants(const_dict):
"""
new_const_dict = {}

def collect_encoding_definitions(stmt, old_buffer_to_const):
def collect_encoding_definitions(stmt, old_buffer_var_to_const):
# Map from copy destination to copy source.
copy_map = {}
# List of buffer copies that occurred
Expand Down Expand Up @@ -376,7 +375,7 @@ def _declare_constant_buffer(old_buffer, encoded_constants, split_idx):
def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
"""Encode the weights or align the bias either for one or two cores,
depending on the variant."""
constant = old_buffer_to_const[buffer1]
constant = old_buffer_var_to_const[buffer1.data]

# If we have just one core, encode the whole constant
if buffer2 is None:
Expand Down Expand Up @@ -471,7 +470,12 @@ def _visit(stmt):
}

def transform_stmt(
stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx
stmt,
buf_remap,
var_remap,
pointer_to_buffer,
new_buffer_var_to_const,
new_buffer_to_split_idx,
):
def _visit_rewrite(stmt):
if isinstance(stmt, tvm.tir.Call):
Expand All @@ -485,7 +489,7 @@ def _visit_rewrite(stmt):
# encoded buffer, the current should be a length.
if (
isinstance(prev_arg, tvm.tir.BufferLoad)
and prev_arg.buffer in new_buffer_to_const
and prev_arg.buffer.data in new_buffer_var_to_const
):
buffer_size = np.prod(list(prev_arg.buffer.shape))
arg = buffer_size
Expand Down Expand Up @@ -554,28 +558,56 @@ def _visit_rewrite(stmt):
["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"],
)

def _collect_parameter_buffer_aliases(prim_func):
buffer_vars = {}
for param in prim_func.params:
if param in prim_func.buffer_map:
buf = prim_func.buffer_map[param]
buffer_vars[buf.data] = {buf}

def visit(node):
if isinstance(node, (tvm.tir.BufferStore, tvm.tir.BufferLoad, tvm.tir.DeclBuffer)):
buf = node.buffer
if buf.data in buffer_vars:
buffer_vars[buf.data].add(buf)

tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit)
return buffer_vars

def _ftransform(f, mod, ctx):
param_buffer_var_usage = _collect_parameter_buffer_aliases(f)

# Step 0: Unpack the constant dictionary in terms of the
# functions buffers.
old_buffer_to_const = {}
old_buffer_var_to_const = {}
for i, param in enumerate(f.params):
if i in const_dict:
old_buffer_to_const[f.buffer_map[param]] = const_dict[i]
old_buffer_var_to_const[f.buffer_map[param].data] = const_dict[i]

# Step 1: Collect information on the buffers that will be
# replaced by encodings.
buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const)
buffer_information = collect_encoding_definitions(f.body, old_buffer_var_to_const)

# Step 2: Generate variable/buffer remaps, based on the
# collected information.
buf_remap = {}
new_buffer_to_const = {}
new_buffer_var_to_const = {}
new_buffer_to_split_idx = {}

def define_remap(old_buf, new_buf):
try:
old_buffers = param_buffer_var_usage[old_buf.data]
except KeyError:
old_buffers = [old_buf]

for old_buffer in old_buffers:
buf_remap[old_buffer] = new_buf

# Any encoded buffers must be replaced
for info in buffer_information["constant_buffer_replacements"]:
buf_remap[info["old_buffer"]] = info["new_buffer"]
new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"]
define_remap(info["old_buffer"], info["new_buffer"])

new_buffer_var_to_const[info["new_buffer"].data] = info["encoded_constants"]

if info["split_idx"]:
new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"]
Expand All @@ -596,9 +628,11 @@ def _ftransform(f, mod, ctx):
name=copy_dest.name,
scope=copy_dest.scope(),
)
buf_remap[copy_dest] = new_dest
if copy_source in new_buffer_to_const:
new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source]
define_remap(copy_dest, new_dest)
if copy_source.data in new_buffer_var_to_const:
new_buffer_var_to_const[new_dest.data] = new_buffer_var_to_const[
copy_source.data
]

if copy_source in new_buffer_to_split_idx:
new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source]
Expand All @@ -615,7 +649,7 @@ def _ftransform(f, mod, ctx):
buf_remap,
var_remap,
pointer_to_buffer,
new_buffer_to_const,
new_buffer_var_to_const,
new_buffer_to_split_idx,
)

Expand All @@ -626,10 +660,10 @@ def _ftransform(f, mod, ctx):
if buffer in buf_remap:
buffer = buf_remap[buffer]

if buffer in new_buffer_to_const:
new_const_dict[i] = new_buffer_to_const[buffer].flatten()
elif buffer in old_buffer_to_const:
new_const_dict[i] = old_buffer_to_const[buffer].flatten()
if buffer.data in new_buffer_var_to_const:
new_const_dict[i] = new_buffer_var_to_const[buffer.data].flatten()
elif buffer.data in old_buffer_var_to_const:
new_const_dict[i] = old_buffer_var_to_const[buffer.data].flatten()

new_buffer_map[param] = buffer

Expand All @@ -638,7 +672,6 @@ def _ftransform(f, mod, ctx):
new_body,
f.ret_type,
new_buffer_map,
f.preflattened_buffer_map,
f.attrs,
f.span,
)
Expand Down Expand Up @@ -873,7 +906,6 @@ def CreatePrimFuncWithoutConstants(const_dict):
def _ftransform(f, mod, ctx):
new_params = list()
new_buffer_map = dict()
new_preflattened_buffer_map = dict()
for param_idx in const_dict.keys():
# We are using buffer_var to key the constants as
# PrimFunc params of constants will be removed.
Expand All @@ -882,14 +914,11 @@ def _ftransform(f, mod, ctx):
if i not in const_dict.keys():
new_params.append(param)
new_buffer_map[param] = f.buffer_map[param]
if param in f.preflattened_buffer_map:
new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param]
return tvm.tir.PrimFunc(
new_params,
f.body,
f.ret_type,
new_buffer_map,
new_preflattened_buffer_map,
f.attrs,
f.span,
)
Expand Down
69 changes: 0 additions & 69 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,74 +314,6 @@ def match_buffer(
)


def preflattened_buffer(
postflattened: Buffer,
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
dtype: str = "float32",
data: Var = None,
strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None,
scope: str = "global",
align: int = -1,
offset_factor: int = 0,
buffer_type: str = "default",
axis_separators: List[int] = None,
) -> None:
"""The pre-flattened buffer statement.
Parameters
----------
postflattened : Buffer
The original buffer to be flattened.
shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral]
        The shape of the buffer prior to flattening.
dtype : str
The data type in the content of the buffer.
data : Var
The pointer to the head of the data.
strides : List[PrimExpr]
The strides of each dimension.
elem_offset : PrimExpr
The offset in terms of number of dtype elements (including lanes).
scope : str
The optional storage scope of buffer data pointer.
align : int
The alignment requirement of data pointer in bytes.
offset_factor : int
The factor of elem_offset field.
buffer_type : str
The buffer type.
axis_separators : List[int]
The separators between input axes when generating flattened output axes.
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is None:
strides = []
_ffi_api.PreflattenedBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
postflattened,
shape,
dtype,
data,
strides,
elem_offset,
scope,
align,
offset_factor,
buffer_type,
axis_separators,
)


def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
"""The block declaration statement.
Expand Down Expand Up @@ -1697,7 +1629,6 @@ def f():
"func_attr",
"func_ret",
"match_buffer",
"preflattened_buffer",
"block",
"init",
"where",
Expand Down
Loading

0 comments on commit 78b5322

Please sign in to comment.