
fix variable op conversion to tosa error in ninja c1 (#8412)
* pub

* move test iree resnet python script to oneflow_iree repo

* add bracket

* rename const_val to const_val_ and restore resnet.py test script

Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
howin98 and jackalcooper authored Jun 13, 2022
1 parent c289645 commit e8547b4
Showing 8 changed files with 101 additions and 50 deletions.
4 changes: 4 additions & 0 deletions oneflow/api/common/variable_tensor_mgr.h
@@ -28,6 +28,10 @@ inline Maybe<void> FillVariableTensorMgr(
auto mgr = Global<VariableTensorMgr>::Get();
return mgr->Fill(variable_op_names, variable_tensors);
}
inline void ClearVariableTensorMgr() {
auto mgr = Global<VariableTensorMgr>::Get();
mgr->Clear();
}

inline std::tuple<std::vector<std::string>, std::vector<std::shared_ptr<one::Tensor>>>
DumpVariableTensorMgr() {
1 change: 1 addition & 0 deletions oneflow/api/python/framework/variable_tensor_mgr.cpp
@@ -26,6 +26,7 @@ namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("FillVariableTensorMgr", &FillVariableTensorMgr);
m.def("DumpVariableTensorMgr", &DumpVariableTensorMgr);
m.def("ClearVariableTensorMgr", &ClearVariableTensorMgr);
}

} // namespace oneflow
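
For reference, a minimal sketch (not part of the diff) of how these three bindings are exercised from Python, mirroring the nn.Graph changes later in this commit; the state_op_names/state_tensors values here are illustrative stand-ins for a graph's variable state:

import oneflow as flow

# Illustrative variable state; in nn.Graph these come from the graph's module states.
state_op_names = ["model.weight"]
state_tensors = [flow.ones(2, 2)]

# Hand the variables to the global VariableTensorMgr (graph.py does this before job completion).
flow._oneflow_internal.FillVariableTensorMgr(state_op_names, state_tensors)

# Read the (possibly rewritten) variables back after compilation.
state_op_names, state_tensors = flow._oneflow_internal.DumpVariableTensorMgr()

# New in this commit: drop the manager's references, e.g. when the graph is destroyed.
flow._oneflow_internal.ClearVariableTensorMgr()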
4 changes: 4 additions & 0 deletions oneflow/core/framework/variable_tensor_mgr.cpp
@@ -63,6 +63,10 @@ VariableTensorMgr::Dump() {
return std::make_tuple(variable_op_names, variable_tensors);
}

void VariableTensorMgr::Clear() {
std::map<std::string, std::shared_ptr<one::Tensor>>().swap(variables_);
}

std::vector<std::string> VariableTensorMgr::DumpNames() {
std::vector<std::string> variable_op_names;
for (const auto& x : variables_) { variable_op_names.push_back(x.first); }
1 change: 1 addition & 0 deletions oneflow/core/framework/variable_tensor_mgr.h
@@ -45,6 +45,7 @@ class VariableTensorMgr final {
const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors);
std::tuple<std::vector<std::string>, std::vector<std::shared_ptr<one::Tensor>>> Dump();
std::vector<std::string> DumpNames();
void Clear();

private:
friend class Global<VariableTensorMgr>;
4 changes: 4 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowOps.td
@@ -282,6 +282,10 @@ def LowerOneFlowToTosaPass : Pass<"lower-oneflow-to-tosa", "ModuleOp"> {
let summary = "";
let constructor = "mlir::oneflow::createLowerOneFlowToTosaPass()";
let dependentDialects = ["tosa::TosaDialect", "memref::MemRefDialect", "mlir::func::FuncDialect"];
let options = [
Option<"variableAsConstant", "variable-as-constant", "int", "0",
"convert variable op as const op of tosa">,
];
}

def MapSCFToGPUPass : Pass<"gpu-greedy-parallel-loop-mapping", "ModuleOp"> {
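In effect, the new variable-as-constant option covers the case where the pass is not driven from Python and no global VariableTensorMgr exists (for example, lowering a standalone .mlir file with oneflow-opt): each oneflow variable op is then rewritten into a tosa const op splat-filled with the option's integer value (0 by default), as implemented by VariableOpToConstLowering in the conversion below.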
58 changes: 51 additions & 7 deletions oneflow/ir/lib/OneFlow/Conversion/OneFlowToTosa.cpp
@@ -168,10 +168,10 @@ struct VariableOpLowering final : public OpConversionPattern<VariableOp> {
LogicalResult matchAndRewrite(VariableOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
const auto mgr = ::oneflow::Global<::oneflow::VariableTensorMgr>::Get();
// decide whether call by python or not
if (!mgr) op->emitError("oneflow variable op doesn't support pure mlir file conversion");
if (!mgr) op->emitError("global variable tensor manager miss");

const auto tensor = mgr->Get(op.op_name().str());
if (!tensor) op->emitError("tensor is null");
const auto value = support::TensorToDenseElementsAttr(tensor, rewriter.getContext());
const auto output = op.output().getType();

@@ -180,6 +180,41 @@ struct VariableOpLowering final : public OpConversionPattern<VariableOp> {
}
};

struct VariableOpToConstLowering final : public OpConversionPattern<VariableOp> {
public:
VariableOpToConstLowering(TypeConverter& typeConverter, MLIRContext* context, int const_val)
: OpConversionPattern<VariableOp>(typeConverter, context), const_val_(const_val){};

using OpConversionPattern<VariableOp>::OpConversionPattern;
LogicalResult matchAndRewrite(VariableOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
const auto output = op.output().getType();
const auto type = output.cast<ShapedType>().getElementType();

// TODO: more control about this scope with flag
if (type.isa<FloatType>()) {
const auto float_attr = rewriter.getFloatAttr(type, const_val_);
auto value = DenseElementsAttr::get(output, float_attr);

rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output, value);
} else if (auto integerType = type.dyn_cast<IntegerType>()) {
const auto int_attr =
rewriter.getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), const_val_));
auto value = DenseElementsAttr::get(output, int_attr);

rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output, value);
} else {
op->emitError(
"OneFlow variable op lower to TOSA const op only support integer and float value now");
}

return success();
}

private:
int const_val_;
};

struct CastOpLowering final : public OpConversionPattern<CastOp> {
public:
using OpConversionPattern<CastOp>::OpConversionPattern;
@@ -547,11 +582,20 @@ void OneFlowLoweringToTosaPass::runOnOperation() {
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
RewritePatternSet patterns(context);
patterns.add<CastOpLowering, ScalarMulByTensorOpLowering, ReluOpLowering, Conv2DOpLowering,
AvgPool2DOpLowering, FlattenOpLowering, Add2OpLowering, MaxPool2DOpLowering,
MatmulOpLowering, BroadcastAddOpLowering, JobLowering, ReturnOpLowering,
VariableOpLowering, InputOpLowering, OutputOpLowering, NormalizationOpLowering,
NormalizationInferenceOpLowering>(typeConverter, context);

const auto mgr = ::oneflow::Global<::oneflow::VariableTensorMgr>::Get();
// judge whether the pass is triggered by python through the existence of the variable tensor manager
if (mgr) {
patterns.add<VariableOpLowering>(typeConverter, context);
} else {
patterns.add<VariableOpToConstLowering>(typeConverter, context, this->variableAsConstant);
}
patterns
.add<CastOpLowering, ScalarMulByTensorOpLowering, ReluOpLowering, Conv2DOpLowering,
AvgPool2DOpLowering, FlattenOpLowering, Add2OpLowering, MaxPool2DOpLowering,
MatmulOpLowering, BroadcastAddOpLowering, JobLowering, ReturnOpLowering, InputOpLowering,
OutputOpLowering, NormalizationOpLowering, NormalizationInferenceOpLowering>(
typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {
getOperation()->dump();
signalPassFailure();
58 changes: 26 additions & 32 deletions oneflow/ir/test/Frontend/test_iree_resnet.py
@@ -24,40 +24,37 @@
import numpy as np
import time


os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS"] = "1"


def _test_iree_resnet_cpu(test_case):
model = resnet50(pretrained=True)
model.eval()

class GraphModule(flow.nn.Graph):
class GraphModuleForIree(flow.nn.Graph):
def __init__(self):
super().__init__()
self.model = model

def build(self, x):
return self.model(x)

func = Runner(GraphModule, return_numpy=True)
class GraphModuleForOFMLIR(flow.nn.Graph):
def __init__(self):
super().__init__()
self.model = model

def build(self, x):
return self.model(x)

func = Runner(GraphModuleForIree, return_numpy=True)
input = flow.ones([1, 3, 224, 224])
f = GraphModule()
for iter in range(3):
print("======== in cpu iter" + str(iter + 1))
f = GraphModuleForOFMLIR()
for iter in range(2):
iree_output = func(input)
start_time = time.time()
graph_output = f(input)
gap = time.time() - start_time
print("graph cost: " + str(gap))
graph_output = graph_output.cpu().detach().numpy()
rtol = np.abs((graph_output - iree_output) / iree_output)
np.set_printoptions(threshold=np.inf)
print(
np.transpose(
np.concatenate((graph_output, iree_output, rtol), axis=0), [1, 0]
)
)
# the rtol accumulate layer by layer
test_case.assertTrue(
np.allclose(iree_output, graph_output, rtol=1.0e-1, atol=1e-3)
@@ -68,32 +65,29 @@ def _test_iree_resnet_cuda(test_case):
model = resnet50(pretrained=True).cuda()
model.eval()

class GraphModule(flow.nn.Graph):
class GraphModuleForIree(flow.nn.Graph):
def __init__(self):
super().__init__()
self.model = model

def build(self, x):
return self.model(x)

func = Runner(GraphModule, return_numpy=True).cuda()
class GraphModuleForOFMLIR(flow.nn.Graph):
def __init__(self):
super().__init__()
self.model = model

def build(self, x):
return self.model(x)

func = Runner(GraphModuleForIree, return_numpy=True)
input = flow.ones([1, 3, 224, 224]).cuda()
f = GraphModule()
for iter in range(3):
print("======== in cuda iter" + str(iter + 1))
f = GraphModuleForOFMLIR()
for iter in range(2):
iree_output = func(input)
start_time = time.time()
graph_output = f(input)
gap = time.time() - start_time
print("graph cost: " + str(gap))
graph_output = graph_output.cpu().detach().numpy()
rtol = np.abs((graph_output - iree_output) / iree_output)
np.set_printoptions(threshold=np.inf)
print(
np.transpose(
np.concatenate((graph_output, iree_output, rtol), axis=0), [1, 0]
)
)
# the rtol accumulate layer by layer
test_case.assertTrue(
np.allclose(iree_output, graph_output, rtol=1.0e-1, atol=1e-3)
21 changes: 10 additions & 11 deletions python/oneflow/nn/graph/graph.py
@@ -525,7 +525,7 @@ def _shallow_repr(self):
return shallow_repr

def _ops_repr(self):
r"""Generate this graph's operators' string representation
r"""Generate this graph's operators' string representation
"""
if self._is_compiled:
conf = self._graph_proto.module_name2module_conf[
@@ -898,10 +898,9 @@ def __build_graph(self, *args, **kwargs):
)
enable_mlir_inference_opt = False
del os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"]
if enable_mlir_inference_opt:
oneflow._oneflow_internal.FillVariableTensorMgr(
state_op_names, self._state_tensor_tuple
)
oneflow._oneflow_internal.FillVariableTensorMgr(
state_op_names, self._state_tensor_tuple
)
# Complete the graph job proto
oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
# Save full graph job proto after job Complete for find real output blob shape and build it.
@@ -941,12 +940,11 @@
self._c_nn_graph.register_output_op_names_and_tensors(
output_op_names, self._outputs_tensor_tuple
)
if enable_mlir_inference_opt:
(
state_op_names,
state_tensors,
) = oneflow._oneflow_internal.DumpVariableTensorMgr()
self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors)
(
state_op_names,
state_tensors,
) = oneflow._oneflow_internal.DumpVariableTensorMgr()
self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors)

self._c_nn_graph.register_variable_op_names_and_tensors(
state_op_names, self._state_tensor_tuple
@@ -1354,6 +1352,7 @@ def __del__(self):
# So it's safe to skip sync here.
return
oneflow._oneflow_internal.eager.Sync()
oneflow._oneflow_internal.ClearVariableTensorMgr()

def __ensure_input_tensors_contiguous(self, *args, **kwargs):
args_tree = ArgsTree((args, kwargs), False)
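
For context, a minimal sketch (illustrative usage, not part of the diff) of the user-facing flow these graph.py changes sit behind: compiling any nn.Graph now always routes its variables through FillVariableTensorMgr/DumpVariableTensorMgr, and deleting the graph clears the manager:

import oneflow as flow

class LinearGraph(flow.nn.Graph):
    def __init__(self, module):
        super().__init__()
        self.m = module

    def build(self, x):
        return self.m(x)

m = flow.nn.Linear(4, 4)
m.eval()
g = LinearGraph(m)
x = flow.ones(1, 4)

# First call compiles the graph; __build_graph now fills and dumps the
# variable tensor manager unconditionally (not only when MLIR inference
# optimization is enabled).
y = g(x)

# __del__ now also calls ClearVariableTensorMgr, releasing the variable
# tensors held by the global manager.
del g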
