Skip to content

Commit

Permalink
[WEB] WebGPU Codegen
Browse files Browse the repository at this point in the history
This PR provides an implementation of WebGPU codegen.
Previously we relied on SPIRV codegen for WebGPU, which
is deprecated in favor of the WGSL shading language.
Passes limited testing on elementwise ops via Chrome.
Likely we will do future iterations.
  • Loading branch information
tqchen committed Feb 20, 2023
1 parent 9f28b1d commit 0a6b038
Show file tree
Hide file tree
Showing 16 changed files with 817 additions and 105 deletions.
17 changes: 17 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/topi/elemwise.h>

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -118,6 +119,22 @@ TVM_REGISTER_OP("tir.nearbyint")
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
DispatchPureExtern<FloatSuffix>);

/*!
 * \brief Lower an erf call to the fast_erf approximation provided by topi.
 *
 * Meant for targets where erf is not available as a native intrinsic.
 * Only float16 and float32 arguments are supported; any other argument
 * type is reported as a fatal error.
 *
 * \param e The expression to lower. Must be a CallNode with exactly one argument.
 * \return The expression produced by topi::fast_erf_float_expr for the argument.
 */
PrimExpr DispatchFastErf(const PrimExpr& e) {
  LOG(WARNING) << "fast_erf will be used instead of erf";
  const CallNode* call = e.as<CallNode>();
  ICHECK(call != nullptr);
  ICHECK_EQ(call->args.size(), 1);
  const PrimExpr arg = call->args[0];
  const auto dtype = arg.dtype();
  const int bits = dtype.bits();
  // This rule is shared by several backends, so the diagnostic names the
  // offending dtype instead of a specific target.
  if (!dtype.is_float() || (bits != 16 && bits != 32)) {
    LOG(FATAL) << "Unsupported type in fast_erf: " << dtype;
  }
  return topi::fast_erf_float_expr(arg, bits);
}

} // namespace intrin

namespace legalize {
Expand Down
3 changes: 3 additions & 0 deletions src/target/intrin_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
}
}

/*!
 * \brief Dispatch ERF to fast erf when it is not available.
 * \param e The erf call expression to lower.
 * \return The lowered expression.
 */
PrimExpr DispatchFastErf(const PrimExpr& e);

} // namespace intrin
} // namespace codegen
} // namespace tvm
Expand Down
6 changes: 3 additions & 3 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
*/
void RegisterHandleType(const VarNode* buf_var, DataType t);
// override
void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final;
void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();

Expand All @@ -281,10 +281,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
const Op& builtin_call_extern_ = builtin::call_extern();
const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
Integer constants_byte_alignment_ = 16;

private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};

private:
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("_");
name_supply_->FreshName("v_");

// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
Expand Down
3 changes: 2 additions & 1 deletion src/target/source/codegen_source_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
}
}
SSAEntry e;
e.vid = name_supply_->FreshName("_");
// use v_ prefix so it works for most systems
e.vid = name_supply_->FreshName("v_");
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
ssa_assign_map_[src] = e;
this->PrintIndent();
Expand Down
Loading

0 comments on commit 0a6b038

Please sign in to comment.