Skip to content

Commit

Permalink
Add NameSupply and GlobalVarSupply
Browse files Browse the repository at this point in the history
  • Loading branch information
gigiblender committed Jul 12, 2022
1 parent fc419df commit a7a9278
Show file tree
Hide file tree
Showing 41 changed files with 701 additions and 367 deletions.
11 changes: 8 additions & 3 deletions include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,15 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/

TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
GlobalVarSupply global_var_supply, bool simple_mode = false);

/*!
* \brief Build an IRModule given a TE schedule, args and binds. This function also applies
Expand All @@ -115,13 +116,14 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
* \param args The arguments to the function (Array of Tensor, Buffer and Vars)
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
GlobalVarSupply global_var_supply, bool simple_mode = false);

/*!
* \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
Expand All @@ -130,10 +132,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module and when creating
* GlobalVars.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply);
/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
Expand Down
79 changes: 79 additions & 0 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
#define TVM_IR_GLOBAL_VAR_SUPPLY_H_

#include <string>
#include <unordered_map>

#include "tvm/ir/expr.h"
#include "tvm/ir/name_supply.h"

namespace tvm {

class GlobalVarSupplyNode : public Object {
public:
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}

explicit GlobalVarSupplyNode(NameSupply name_supply);

GlobalVar FreshGlobal(String name, bool add_prefix = true);

GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);

void VisitAttrs(AttrVisitor* v) { v->Visit("name_supply", &name_supply_); }

NameSupply name_supply_;

static constexpr const char* _type_key = "GlobalVarSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);

private:
std::unordered_map<std::string, GlobalVar> name_to_var_map_;

friend class GlobalVarSupply;
};

class GlobalVarSupply : public ObjectRef {
public:
TVM_DLL explicit GlobalVarSupply(
const NameSupply& name_supply = NameSupply::NameSupplyWithPrefix(""),
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

TVM_DLL static GlobalVarSupply GlobalVarSupplyFromNameSupply(const NameSupply& name_supply);

TVM_DLL static GlobalVarSupply EmptySupply();

explicit GlobalVarSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
GlobalVarSupplyNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<GlobalVarSupplyNode*>(ptr);
}

TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarSupplyNode);
};

} // namespace tvm

#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_
23 changes: 12 additions & 11 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/ir/adt.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/type.h>
#include <tvm/parser/source_map.h>
#include <tvm/runtime/container/array.h>
Expand Down Expand Up @@ -64,6 +65,8 @@ class IRModuleNode : public Object {
/* \brief Additional attributes storing meta-data about the module. */
DictAttrs attrs;

GlobalVarSupply global_var_supply;

/*!
* \brief Get a module attribute.
*
Expand Down Expand Up @@ -125,6 +128,7 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
v->Visit("source_map", &source_map);
v->Visit("attrs", &attrs);
v->Visit("global_var_supply", &global_var_supply);
}

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
Expand Down Expand Up @@ -323,14 +327,6 @@ class IRModuleNode : public Object {
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Returns a version of \p name which is unique amongst all function definitions in module.
*
* \param name The original name.
* \return Updated name which is unique.
*/
String GetUniqueName(const String& name);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Expand Down Expand Up @@ -362,12 +358,14 @@ class IRModule : public ObjectRef {
/*!
* \brief constructor
* \param functions Functions in the module.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module.
* \param map The module source map.
* \param attrs The module attributes.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
DictAttrs attrs = {});
Expand Down Expand Up @@ -403,6 +401,7 @@ class IRModule : public ObjectRef {
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map. Default empty.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param type_definitions The global type definition map. Default empty.
* \param import_set Set of external modules already imported. Default empty.
*
Expand All @@ -413,16 +412,18 @@ class IRModule : public ObjectRef {
*/
static std::pair<IRModule, GlobalVar> FromExprInContext(
const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
const Map<GlobalTypeVar, TypeData>& type_definitions = {},
std::unordered_set<String> import_set = {});

/*!
* \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no
* imports.
*/
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
const Map<GlobalVar, BaseFunc>& global_funcs = {},
const Map<GlobalTypeVar, TypeData>& type_definitions = {});
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
const Map<GlobalTypeVar, TypeData>& type_definitions = {});

/*!
* \brief Parse text format source file into an IRModule.
Expand Down
89 changes: 89 additions & 0 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_IR_NAME_SUPPLY_H_
#define TVM_IR_NAME_SUPPLY_H_

#include <string>
#include <unordered_map>

#include "tvm/ir/expr.h"

namespace tvm {

class NameSupplyNode : public Object {
public:
NameSupplyNode() : NameSupplyNode("") {}

explicit NameSupplyNode(const String& prefix);

String FreshName(const String& name, bool add_prefix = true);

String ReserveName(const String& name, bool add_prefix = true);

bool ContainsName(const String& name, bool add_prefix = true);

void Clear();

void VisitAttrs(AttrVisitor* v) { v->Visit("prefix", &prefix_); }

// Prefix for all GlobalVar names. It can be empty.
std::string prefix_;

static constexpr const char* _type_key = "NameSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);

private:
String prefix_module_name(const String& name);

std::string GetUniqueName(std::string name);

// Key is function_name. Value is a counter.
std::unordered_map<std::string, int> name_map;

friend class NameSupply;
};

class NameSupply : public ObjectRef {
public:
TVM_DLL NameSupply();

TVM_DLL explicit NameSupply(const String& prefix,
std::unordered_map<std::string, int> name_map = {});

TVM_DLL static NameSupply NameSupplyWithPrefix(const String& prefix = "");

TVM_DLL static NameSupply EmptySupply();

explicit NameSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
NameSupplyNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<NameSupplyNode*>(ptr);
}

TVM_DEFINE_OBJECT_REF_COW_METHOD(NameSupplyNode);
};

} // namespace tvm

#endif // TVM_IR_NAME_SUPPLY_H_
5 changes: 3 additions & 2 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,16 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
*
* \param expr An expression to evaluate.
* \param type_definitions Global type definitions which \p expr may references.
* \param global_var_supply The GlobalVarSupply to be used during evaluation.
* \param import_set Already imported external modules.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* \param attrs Attributes for the expression to be evaluated with
* @return The object representing the result.
*/
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target,
Map<String, ObjectRef> attrs = {});
GlobalVarSupply global_var_supply, std::unordered_set<String> import_set,
Device device, Target target, Map<String, ObjectRef> attrs = {});

} // namespace relay
} // namespace tvm
Expand Down
16 changes: 12 additions & 4 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""IRModule that holds the functions and type definitions."""
from tvm._ffi.base import string_types
import tvm._ffi
from tvm.ir.supply import GlobalVarSupply

from .base import Node
from . import expr as _expr
Expand All @@ -36,7 +37,7 @@ class IRModule(Node):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None):
def __init__(self, functions=None, type_definitions=None, globar_var_supply=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -59,7 +60,11 @@ def __init__(self, functions=None, type_definitions=None):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
if globar_var_supply is None:
globar_var_supply = GlobalVarSupply()
self.__init_handle_by_constructor__(
_ffi_api.IRModule, functions, type_definitions, globar_var_supply
)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down Expand Up @@ -217,7 +222,7 @@ def get_type(self, name):
return tuple([ty_var] + list(ty_data.constructors))

@staticmethod
def from_expr(expr, functions=None, type_defs=None):
def from_expr(expr, functions=None, type_defs=None, global_var_supply=None):
"""Construct a module from a standalone expression.
Parameters
Expand All @@ -238,9 +243,12 @@ def from_expr(expr, functions=None, type_defs=None):
where expr is set as the entry point
(wrapped in a function if necessary)
"""
global_var_supply = (
global_var_supply if global_var_supply is not None else GlobalVarSupply()
)
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _ffi_api.Module_FromExpr(expr, funcs, defs)
return _ffi_api.Module_FromExpr(expr, funcs, global_var_supply, defs)

def _import(self, file_to_import):
return _ffi_api.Module_Import(self, file_to_import)
Expand Down
Loading

0 comments on commit a7a9278

Please sign in to comment.