Fix mixed precision output type to original type (#11142)
gayatripk1 authored May 5, 2022
1 parent 5007033 commit eae836c
Showing 2 changed files with 82 additions and 17 deletions.
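
For orientation before the diff, here is a minimal usage sketch of the option this commit introduces; it is not part of the changed files, and the conv2d workload, shapes, and variable names are illustrative assumptions:

# Hedged usage sketch: enable the new PassContext option so ToMixedPrecision
# casts the function's outputs back to their original dtype (float32 here).
import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.conv2d(data, weight, padding=(1, 1)))
mod = InferType()(mod)  # the pass reads checked types, so infer them first

with tvm.transform.PassContext(
    config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
    fp16_mod = ToMixedPrecision("float16")(mod)

# The body now computes in float16, while the final expression casts back to float32.
print(InferType()(fp16_mod)["main"].ret_type)

When the option is left at its default of false, the converted function keeps float16 outputs, which is the behaviour prior to this change.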
60 changes: 53 additions & 7 deletions src/relay/transforms/to_mixed_precision.cc
@@ -36,6 +36,7 @@
namespace tvm {
namespace relay {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
// A callable which hashes std::pair
struct pair_hash {
template <class T1, class T2>
@@ -105,6 +106,9 @@ class MixedPrecisionPass : public MixedModeMutator {
* encountered. Used for emitting warnings on missing ops in the pass.
*/
std::unordered_map<std::string, int> missing_ops_;
const RelayExprNode* root_;
std::vector<DataType> original_dtype_;
bool keep_orig_output_dtype_;

Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
/* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
public:
using MixedModeMutator::VisitExpr_;

explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
: MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
DataType mixed_precision_type = DataType::Float(16))
: MixedModeMutator(),
mixed_precision_type_(mixed_precision_type),
root_(Downcast<Function>(base)->body.get()),
keep_orig_output_dtype_(keep_orig_output_dtype) {
if (keep_orig_output_dtype_) {
if (root_->IsInstance<tvm::relay::TupleNode>()) {
const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
for (Type t : tuple_type->fields) {
const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
original_dtype_.push_back(tensor_type->dtype);
}
} else if (root_->IsInstance<tvm::relay::CallNode>()) {
original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
}
}
if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
<< mixed_precision_type_;
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
if (accumulation_dtype != output_dtype) {
output = CastArg(output, GetType(output), output_dtype);
}
if (pre_call_node == root_ && keep_orig_output_dtype_) {
if (original_dtype_[0] != output_dtype) {
output = CastArg(output, GetType(output), original_dtype_[0]);
}
}
return output;
}

@@ -396,6 +420,21 @@
Expr Rewrite_(const TupleNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
if (pre == root_ && keep_orig_output_dtype_) {
Array<Expr> new_expr;
bool all_same = true;
for (size_t i = 0; i < original_dtype_.size(); i++) {
Expr output_element = GetField(post, i);
Expr casted_element;
auto output_element_type = transform::InferTypeLocal(output_element);
casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
new_expr.push_back(casted_element);
all_same &= casted_element.same_as(output_element);
}
if (!all_same) {
return Tuple(new_expr);
}
}
return post;
}

Expand All @@ -421,11 +460,12 @@ class MixedPrecisionPass : public MixedModeMutator {
}

// To access map of ops not registered for error reporting
friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
int missing_op_mode);
friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
const DataType& mixed_precision_type, int missing_op_mode);
};

Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
const DataType& mixed_precision_type, int missing_op_mode) {
/*
missing_op_mode:
@@ -436,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
<< " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;

MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
MixedPrecisionPass converter =
MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
auto result = converter.Mutate(expr);

for (auto it = converter.missing_ops_.begin();
@@ -460,7 +501,12 @@ namespace transform {
Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(ToMixedPrecision(f, mixed_precision_type, missing_op_mode));
bool keep_orig_output_dtype = false;
keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype",
Bool(keep_orig_output_dtype))
.value();
return Downcast<Function>(
ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
};
return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
}
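To complement the Rewrite_(const TupleNode*) branch above, here is a hedged sketch of the multi-output case; the dense/relu workload and shapes are assumptions used only for illustration:

# Hedged sketch: a tuple-valued output; with the option enabled, every tuple
# field is expected to come back as float32 even though the body runs in float16.
import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

x = relay.var("x", shape=(1, 16), dtype="float32")
w = relay.var("w", shape=(16, 16), dtype="float32")
dense = relay.nn.dense(x, w)
mod = tvm.IRModule.from_expr(relay.Tuple([dense, relay.nn.relu(dense)]))
mod = InferType()(mod)

with tvm.transform.PassContext(
    config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
    fp16_mod = ToMixedPrecision("float16")(mod)

# Expected: a TupleType whose fields are both Tensor[(1, 16), float32].
print(InferType()(fp16_mod)["main"].ret_type)

In the pass above, each field is cast back only when its dtype differs from the recorded original, and the tuple is rebuilt only if at least one field actually changed.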
39 changes: 29 additions & 10 deletions tests/python/relay/test_to_mixed_precision.py
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
mixed_precision_dtype="float16",
rtol: float = 1e-3,
atol: float = 0,
keep_orig_output_dtype=False,
) -> tvm.runtime.Module:

mod = InferType()(mod)
result_fp32 = run_module(mod, mod_params)
fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
result_fp16 = run_module(fp16_mod, mod_params)

if not keep_orig_output_dtype:
fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
result_fp16 = run_module(fp16_mod, mod_params)
else:
with tvm.transform.PassContext(
config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
result_fp16 = run_module(fp16_mod, mod_params)

# Ensure the results are close
for fp32, fp16 in zip(result_fp32, result_fp16):
np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)

if keep_orig_output_dtype:
assert (
np.array(result_fp16).dtype == np.array(result_fp32).dtype
), "output type and original type mismatch"

return fp16_mod


@@ -117,16 +131,21 @@ def test_convert_single_conv():
"data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
"weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
}
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
fp16_mod = verify_mixed_precision_output_close(
mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True
)

expected_mod = tvm.IRModule.from_expr(
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float16",
),
relay.cast(
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float16",
),
"float32",
)
)
expected_mod = tvm.relay.transform.InferType()(expected_mod)
