Commit: Merge branch 'apache:main' into main
canesche authored Feb 6, 2024
2 parents 782d2e2 + 6a3fadc commit 46f5fbe
Showing 132 changed files with 5,267 additions and 1,742 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 56 files
+1 −1 .clang-format
+54 −0 .github/workflows/build-doc.yml
+106 −0 .github/workflows/release_wheel.yml
+1 −0 docs/.gitignore
+20 −0 docs/Makefile
+86 −0 docs/conf.py
+20 −0 docs/index.rst
+35 −0 docs/make.bat
+9 −0 docs/requirements.txt
+127 −19 include/flashinfer/cascade.cuh
+179 −223 include/flashinfer/decode.cuh
+148 −124 include/flashinfer/handler.cuh
+12 −12 include/flashinfer/layout.cuh
+99 −122 include/flashinfer/page.cuh
+125 −117 include/flashinfer/prefill.cuh
+26 −47 include/flashinfer/rope.cuh
+85 −69 include/flashinfer/utils.cuh
+12 −0 python/MANIFEST.in
+49 −28 python/csrc/batch_decode.cu
+74 −50 python/csrc/batch_prefill.cu
+35 −2 python/csrc/cascade.cu
+5 −2 python/csrc/flashinfer_ops.cu
+26 −13 python/csrc/flashinfer_ops.h
+84 −0 python/csrc/page.cu
+1 −1 python/csrc/single_decode.cu
+15 −11 python/flashinfer/__init__.py
+323 −0 python/flashinfer/cascade.py
+327 −0 python/flashinfer/decode.py
+0 −680 python/flashinfer/ops/__init__.py
+0 −12 python/flashinfer/ops/utils.py
+42 −0 python/flashinfer/page.py
+324 −0 python/flashinfer/prefill.py
+56 −0 python/flashinfer/utils.py
+1 −0 python/include
+69 −40 python/setup.py
+29 −12 python/tests/test_batch_decode_kernels.py
+38 −13 python/tests/test_batch_prefill_kernels.py
+204 −3 python/tests/test_shared_prefix_kernels.py
+1 −0 python/version.txt
+8 −0 scripts/ci-flashinfer.env.example
+27 −0 scripts/ci-flashinfer.service
+0 −0 scripts/formatter.sh
+46 −0 scripts/run-ci-build-wheel.sh
+22 −12 src/bench_batch_decode.cu
+53 −35 src/bench_cascade.cu
+7 −7 src/bench_single_decode.cu
+4 −4 src/bench_single_prefill.cu
+11 −8 src/cpu_reference.h
+19 −14 src/test_batch_decode.cu
+16 −12 src/test_batch_prefill.cu
+190 −44 src/test_cascade.cu
+76 −43 src/test_page.cu
+11 −10 src/test_single_decode.cu
+11 −10 src/test_single_prefill.cu
+53 −41 src/tvm_wrapper.cu
+1 −0 version.txt
134 changes: 76 additions & 58 deletions apps/microtvm/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion apps/microtvm/pyproject.toml
@@ -129,7 +129,7 @@ black = "^19.10b0"
matplotlib = "^3.2"
Image = "^1.5"
recommonmark = "^0.6"
pillow = "==10.0.1"
pillow = "==10.2.0"
pyformat = "^0.7"
pylint = "^2.4"
pytest = "==7.2.1"
1 change: 1 addition & 0 deletions ci/scripts/github/github_tvmbot.py
@@ -543,6 +543,7 @@ def rerun_jenkins_ci(self) -> None:
"tvm-minimal-cross-isa",
"tvm-riscv",
"tvm-wasm",
"tvm-unity",
]
for name in job_names:
url = JENKINS_URL + f"job/{name}/job/PR-{self.number}/buildWithParameters"
1 change: 0 additions & 1 deletion docker/install/ubuntu2004_install_python_package.sh
@@ -43,5 +43,4 @@ pip3 install --upgrade \
junitparser==2.4.2 \
six \
tornado \
-pytest-lazy-fixture \
git+https://github.com/jax-ml/ml_dtypes.git@v0.2.0
1 change: 0 additions & 1 deletion docker/install/ubuntu_install_python_package.sh
@@ -44,5 +44,4 @@ pip3 install --upgrade \
junitparser==2.4.2 \
six \
tornado \
-pytest-lazy-fixture \
ml_dtypes
25 changes: 25 additions & 0 deletions include/tvm/ir/transform.h
@@ -525,6 +525,31 @@ TVM_DLL Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
String name, Array<runtime::String> required, bool traceable = false);

/*!
* \brief Utility to apply a pass to specific functions in an IRModule
*
* TVM uses IRModule to IRModule transformations at all stages of
* lowering. These transformations may be useful when hand-writing an
* optimized model, or to perform optimizations on specific kernels
* within an IRModule. This utility allows a pass to be applied to a
* specified function, without altering other functions in the module.
*
* \param pass The IRModule to IRModule pass to be applied.
*
* \param func_name_regex A regex used to select the functions to be
* updated. The pass will be applied to all functions whose name
* matches the regex.
*
* \param error_if_no_function_matches_regex Specifies the behavior if
* an IRModule does not contain any function matching the provided
* regex. If true, an error will be raised. If false (default),
* the IRModule will be returned unmodified.
*
* \return The modified IRModule to IRModule pass.
*/
TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
bool error_if_no_function_matches_regex = false);
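As a sketch of how this utility could be driven from Python. The binding name tvm.ir.transform.ApplyPassToFunction is an assumption; only the C++ signature above is part of this commit:

import tvm
from tvm import relax

# Hypothetical Python mirror of the C++ API above: fold constants only in
# functions whose name matches "main.*", leaving all other functions alone.
scoped = tvm.ir.transform.ApplyPassToFunction(
    relax.transform.FoldConstant(),
    "main.*",
    error_if_no_function_matches_regex=False,  # return the module unmodified on no match
)
# mod = scoped(mod)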

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
21 changes: 17 additions & 4 deletions include/tvm/relax/transform.h
@@ -214,7 +214,9 @@ TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
Optional<String> func_name = NullOpt);

/*!
- * \brief Fold constant expressions.
+ * \brief Fold constant expressions within dataflow blocks.
+ *
+ * \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*
* \return The Pass.
*/
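A minimal pipeline sketch reflecting the new note, assuming the standard Python bindings for these passes:

import tvm
from tvm import relax

# FoldConstant now only rewrites bindings inside dataflow blocks,
# so group pure bindings into dataflow blocks first.
seq = tvm.transform.Sequential(
    [
        relax.transform.ConvertToDataflow(),
        relax.transform.FoldConstant(),
    ]
)
# mod = seq(mod)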
@@ -458,6 +460,8 @@ class PatternCheckContext : public ObjectRef {
* of the return value as the target. If it is not specified, the first return value will be the
* target.
* \return The Pass.
*
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*/
TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = NullOpt,
int target_index = 0);
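For example, a sketch using the usual Python binding; the "main_adjoint" output name is the pass's convention as commonly documented, not stated in this commit:

from tvm import relax

def with_gradient(mod):
    # Ensure "main" uses dataflow blocks, then emit its reverse-mode
    # gradient function (conventionally named "main_adjoint").
    mod = relax.transform.ConvertToDataflow()(mod)
    return relax.transform.Gradient("main")(mod)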
@@ -477,6 +481,8 @@ TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = NullOpt,
* This must be True if the created composite functions are intended to be offloaded to
* an external backend without using the MergeCompositeFunctions pass.
* \return The Pass.
*
* \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
bool annotate_codegen = false);
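A usage sketch under the same dataflow caveat; the pattern-registry import path and the "cutlass" prefix are illustrative assumptions:

from tvm import relax
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix

patterns = get_patterns_with_prefix("cutlass")
fuse = relax.transform.FuseOpsByPattern(
    patterns,
    bind_constants=False,   # keep weights as parameters for the external codegen
    annotate_codegen=True,  # required when skipping MergeCompositeFunctions
)
# mod = fuse(mod)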
@@ -548,6 +554,7 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
* \brief Layout conversion pass.
* \param desired_layouts The desired layouts for some operators.
* \return The Pass.
* \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
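For instance, a sketch; the desired_layouts convention of [input_layout, kernel_layout] per op follows the pass's documented usage:

from tvm import relax

# Request NHWC activations and OHWI kernels for conv2d.
convert = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})
# mod = convert(mod)  # run ConvertToDataflow first if the module has no dataflow blocks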

@@ -564,10 +571,13 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
* \brief Dead code elimination.
* \sa RemoveAllUnused
* Currently it removes:
- * 1. Unused local VarBindings in a DataflowBlock.
- * 2. Unused DataflowBlocks in a function.
- * 3. Unused Relax functions in the module.
+ * 1. Unused local VarBindings
+ *    (those where the bound var is unused and no impure operation is used).
+ * 2. Unused Relax functions in the module.
+ *    We detect the call chain from the entry function, and remove all unused functions.
+ *
+ * Any binding blocks that are left empty will be removed by the normalizer.
*
* \return The Pass.
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
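A sketch of the updated behavior from Python: functions unreachable from the entry function are dropped, along with unused pure bindings.

from tvm import relax

# Keep only what is reachable from "main"; empty blocks left behind
# are cleaned up by the normalizer, per the note above.
dce = relax.transform.DeadCodeElimination(["main"])
# mod = dce(mod)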
@@ -578,6 +588,7 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
* Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place
* PrimFunc implementations of those operators (which are based on the legalizations of those
* operators).
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
* \return The pass.
*/
TVM_DLL Pass DataflowUseInplaceCalls();
@@ -589,6 +600,8 @@ TVM_DLL Pass DataflowUseInplaceCalls();
* \param fp16_input_names The names of function parameters whose dtype should become fp16. The
* function signature would change accordingly.
* \return The Pass.
*
* \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
Optional<Array<String>> fp16_input_names = NullOpt);
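A minimal sketch, assuming the usual Python binding; out_dtype is the dtype kept in full precision at accumulation boundaries:

from tvm import relax

# Compute in fp16 where profitable, accumulating/returning in fp32.
amp = relax.transform.ToMixedPrecision(out_dtype="float32")
# mod = amp(mod)  # mainly rewrites dataflow blocks; ConvertToDataflow may be needed first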
5 changes: 5 additions & 0 deletions include/tvm/runtime/threading_backend.h
@@ -203,6 +203,11 @@ inline void parallel_launch_with_threading_backend(T flambda) {

template <typename T>
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end) {
if (end - begin == 1) {
flambda(begin);
return;
}

auto flaunch = [begin, end, flambda](int task_id, int num_task) {
// For each thread, do static division and call into flambda.
int64_t total_len = end - begin;
6 changes: 6 additions & 0 deletions include/tvm/tir/builtin.h
@@ -909,6 +909,12 @@ TVM_DLL const Op& anylist_setitem_call_packed();
*/
TVM_DLL const Op& anylist_setitem_call_cpacked();

/*!
* \brief Get the target's vscale value. It will be lowered to llvm.vscale intrinsic
* (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic)
*/
TVM_DLL const Op& vscale();
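A hypothetical sketch of referring to the builtin from Python, assuming it is registered as "tir.vscale" (only the C++ declaration is in this commit):

from tvm import tir

# Build a call to the builtin; on SVE-style targets it lowers to llvm.vscale.
vscale = tir.call_intrin("int32", "tir.vscale")
lanes = vscale * 4  # e.g. the number of 32-bit lanes in a <vscale x 4 x i32> vector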

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
6 changes: 2 additions & 4 deletions python/tvm/contrib/emcc.py
@@ -42,14 +42,12 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc"):
cmd += ["-O3"]
cmd += ["-std=c++17"]
cmd += ["--no-entry"]
-# temp disable for backward compact
-# can enable after emsdk updates
-# cmd += ["-fwasm-exception"]
+cmd += ["-fwasm-exception"]
cmd += ["-s", "WASM_BIGINT=1"]
cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"]
cmd += ["-s", "STANDALONE_WASM=1"]
cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"]
cmd += ["-s", "TOTAL_MEMORY=80MB"]
cmd += ["-s", "TOTAL_MEMORY=160MB"]

objects = [objects] if isinstance(objects, str) else objects

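A minimal usage sketch for the function changed above (file names are illustrative; emsdk's emcc must be on PATH):

from tvm.contrib import emcc

emcc.create_tvmjs_wasm("model.wasm", ["model.o"], options=["-s", "ASSERTIONS=1"])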
28 changes: 16 additions & 12 deletions python/tvm/contrib/msc/core/frontend/translate.py
@@ -66,7 +66,12 @@ def _to_data(ref_t, data):
return data

weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if graph.has_tensor(t.name)}
-return weights
+# sort the weights by graph weights
+graph_weights = {}
+for weight in graph.get_weights():
+    assert weight.name in weights, "Missing weight " + str(weight)
+    graph_weights[weight.name] = weights[weight.name]
+return graph_weights


def from_relax(
@@ -115,13 +120,10 @@ def from_relax(
patterns = get_patterns_with_prefix("msc.")
passes = [
-msc_transform.SetExprName(),
-msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
 tvm.relax.transform.FuseOpsByPattern(
     patterns, bind_constants=False, annotate_codegen=False
 ),
+msc_transform.SetExprName(entry_name=entry, target=trans_config.get("target", "")),
+msc_transform.SetExprLayout(
+    trans_config.get("allow_layout_missing", True), entry_name=entry
+),
]
mod = tvm.transform.Sequential(passes)(mod)
graph = _ffi_api.BuildFromRelax(mod, entry, msc_utils.dump_dict(build_config))
@@ -309,13 +311,12 @@ def _partition_mod(mod, as_msc=True):
patterns = get_patterns_with_prefix(target)
passes = [
-msc_transform.SetExprName(),
-msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
 tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=not as_msc),
 msc_transform.BindShape(),
 msc_transform.InlineParams(),
 msc_transform.FuseTuple(target),
 tvm.relax.transform.MergeCompositeFunctions(),
-msc_transform.SetBYOCAttrs(target),
+msc_transform.SetExprName(target=target),
+msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
]
return tvm.transform.Sequential(passes)(mod)

@@ -331,9 +332,12 @@ def _is_target_func(func):
assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod)
BYOCChecker().check(func_names, msc_mod[entry])

-graphs_info, all_weights = [], _ffi_api.GetRelaxWeights(msc_mod, entry)
+ref_weights = _ffi_api.GetRelaxWeights(msc_mod, entry)
+graphs, weights = [], {}
 for name in func_names:
-    build_config.update({"graph_name": msc_mod[name].attrs["byoc_name"], "byoc_entry": name})
+    graph_name = msc_mod[name].attrs[_ffi_api.ToAttrKey("unique")]
+    build_config.update({"graph_name": graph_name, "byoc_entry": name})
     graph = _ffi_api.BuildFromRelax(msc_mod, entry, msc_utils.dump_dict(build_config))
-    graphs_info.append((graph, normalize_weights(all_weights, graph)))
-return _partition_mod(mod, False), graphs_info
+    graphs.append(graph)
+    weights.update(normalize_weights(ref_weights, graph))
+return _partition_mod(mod, False), graphs, weights
13 changes: 13 additions & 0 deletions python/tvm/contrib/msc/core/ir/graph.py
@@ -660,6 +660,19 @@ def get_nodes(self) -> Iterable[MSCJoint]:
for n in self.node_names:
yield self.find_node(n)

def get_weights(self) -> Iterable[MSCTensor]:
"""Get all the weights in the graph.
Returns
-------
weights: generator<MSCTensor>
The generator of weights.
"""

for node in self.get_nodes():
for weight in node.get_weights().values():
yield weight
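For example, with a built graph bound to graph, the new generator walks every weight tensor in node order:

for weight in graph.get_weights():
    print(weight.name)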

def input_at(self, idx: int) -> MSCTensor:
"""Get input at idx.
