From 6e8e9c9d9074e52f61b2a9632474ba5a9eea3bb9 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Thu, 12 May 2022 21:11:24 +0800 Subject: [PATCH 01/67] ThreadLocalGuard --- oneflow/core/common/thread_local_guard.h | 61 ++++++++++++++++++ .../core/common/thread_local_guard_test.cpp | 63 +++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 oneflow/core/common/thread_local_guard.h create mode 100644 oneflow/core/common/thread_local_guard_test.cpp diff --git a/oneflow/core/common/thread_local_guard.h b/oneflow/core/common/thread_local_guard.h new file mode 100644 index 00000000000..b538b476e2d --- /dev/null +++ b/oneflow/core/common/thread_local_guard.h @@ -0,0 +1,61 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_ +#define ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_ + +#include +#include + +namespace oneflow { + +// Interfaces: +// - ThreadLocalGuard::CurrentValue() +// - ThreadLocalGuard::HasCurrentValue() +template +class ThreadLocalGuard; + +template<> +class ThreadLocalGuard { + public: + explicit ThreadLocalGuard(bool value) { + old_value_ = *MutThreadLocalValue(); + *MutThreadLocalValue() = int(value); + } + ~ThreadLocalGuard() { *MutThreadLocalValue() = old_value_; } + + static bool CurrentValue() { + int value = *MutThreadLocalValue(); + CHECK_GE(value, 0); + return value > 0; + } + + static bool HasCurrentValue() { return *MutThreadLocalValue() >= 0; } + + private: + static int* MutThreadLocalValue() { + static thread_local int value = -1; + return &value; + } + + // -1: not exists. + // 0: false. + // 1: true. + int old_value_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_ diff --git a/oneflow/core/common/thread_local_guard_test.cpp b/oneflow/core/common/thread_local_guard_test.cpp new file mode 100644 index 00000000000..e59daa54fd6 --- /dev/null +++ b/oneflow/core/common/thread_local_guard_test.cpp @@ -0,0 +1,63 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/thread_local_guard.h" + +namespace oneflow { +namespace test { + +template +void AssertCurrentValue(const T& value) { + ThreadLocalGuard guard(value); + ASSERT_TRUE(ThreadLocalGuard::HasCurrentValue()); + ASSERT_EQ(ThreadLocalGuard::CurrentValue(), value); +} + +template +void Assert(const T& value0, const T& value1) { + ASSERT_FALSE(ThreadLocalGuard::HasCurrentValue()); + { + ThreadLocalGuard guard(value0); + ASSERT_TRUE(ThreadLocalGuard::HasCurrentValue()); + } + { + ThreadLocalGuard guard(value0); + ASSERT_TRUE(ThreadLocalGuard::HasCurrentValue()); + ASSERT_EQ(ThreadLocalGuard::CurrentValue(), value0); + } + { + ThreadLocalGuard guard(value1); + ASSERT_TRUE(ThreadLocalGuard::HasCurrentValue()); + ASSERT_EQ(ThreadLocalGuard::CurrentValue(), value1); + } + { + ThreadLocalGuard guard(value0); + ASSERT_TRUE(ThreadLocalGuard::HasCurrentValue()); + ASSERT_EQ(ThreadLocalGuard::CurrentValue(), value0); + { + ThreadLocalGuard nested_guard(value1); + ASSERT_TRUE(ThreadLocalGuard::HasCurrentValue()); + ASSERT_EQ(ThreadLocalGuard::CurrentValue(), value1); + } + ASSERT_EQ(ThreadLocalGuard::CurrentValue(), value0); + } +} + +TEST(ThreadLocalGuard, bool) { Assert(true, false); } + +} // namespace test +} // namespace oneflow From 78cc1fc437da37261a7d1db94ee6bdcd7b5babf3 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Thu, 7 Jul 2022 19:24:39 +0800 Subject: [PATCH 02/67] refactor EagerBlobObjectList --- oneflow/api/python/functional/python_frame.h | 1 + oneflow/core/autograd/autograd_engine.cpp | 5 +++++ oneflow/core/eager/call_context.h | 5 +---- .../critical_section_phy_instr_operand.h | 7 ------ oneflow/core/eager/eager_blob_object.h | 10 +++++++++ .../core/eager/lazy_job_phy_instr_operand.h | 7 ------ oneflow/core/framework/nn_graph.cpp | 22 ++++++++----------- .../core/framework/op_expr_grad_function.h | 6 +++++ .../eager_mirrored_op_interpreter.cpp | 17 +++++++++++++- 
.../op_interpreter/op_interpreter.cpp | 6 +++++ .../op_interpreter/op_interpreter_util.cpp | 4 ++++ oneflow/core/framework/tensor_tuple.h | 5 ++++- .../vm/touch_tensors_instruction_type.cpp | 2 +- .../core/vm/touch_tensors_instruction_type.h | 7 +++--- oneflow/core/vm/virtual_machine_engine.cpp | 12 +++++----- oneflow/user/kernels/stateful_opkernel.cpp | 2 +- tools/functional/generator.py | 1 + 17 files changed, 74 insertions(+), 45 deletions(-) diff --git a/oneflow/api/python/functional/python_frame.h b/oneflow/api/python/functional/python_frame.h index c6db38dac15..eb0153571ba 100644 --- a/oneflow/api/python/functional/python_frame.h +++ b/oneflow/api/python/functional/python_frame.h @@ -21,6 +21,7 @@ limitations under the License. #include "oneflow/api/python/functional/common.h" #include "oneflow/core/framework/op_interpreter/dispatch_frame.h" #include "oneflow/core/job/graph_scope_vars.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { diff --git a/oneflow/core/autograd/autograd_engine.cpp b/oneflow/core/autograd/autograd_engine.cpp index bf29b1c117f..fdea77172a2 100644 --- a/oneflow/core/autograd/autograd_engine.cpp +++ b/oneflow/core/autograd/autograd_engine.cpp @@ -28,6 +28,7 @@ limitations under the License. 
#include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/global_param_grad_sync_mode.h" #include "oneflow/core/common/container_util.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { @@ -396,6 +397,7 @@ Maybe GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad( Maybe GraphAutogradEngine::AddNode( const std::string& name, const std::shared_ptr& backward_fn, const TensorTuple& inputs, TensorTuple* outputs) { + OF_PROFILER_RANGE_PUSH("AddAccumulateFunctionNode"); // Firstly push function_node of tensor in stack which is leaf and requires_grad for (const std::shared_ptr& in_tensor : inputs) { if (in_tensor->is_leaf() && in_tensor->requires_grad()) { @@ -403,11 +405,14 @@ Maybe GraphAutogradEngine::AddNode( } } + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("set_grad_fn_node"); std::shared_ptr func_node = GraphFunctionNode::New(name, backward_fn, inputs, *outputs); for (const std::shared_ptr& out_tensor : *outputs) { out_tensor->set_grad_fn_node(func_node); } + OF_PROFILER_RANGE_POP(); return func_node; } diff --git a/oneflow/core/eager/call_context.h b/oneflow/core/eager/call_context.h index 0e7058c0292..ac47f54045f 100644 --- a/oneflow/core/eager/call_context.h +++ b/oneflow/core/eager/call_context.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/stride.h" +#include "oneflow/core/common/small_vector.h" namespace oneflow { @@ -29,10 +30,6 @@ namespace one { class StatefulLocalOpKernel; class ConsistentTensorInferResult; -using EagerBlobObjectList = std::vector>; -using EagerBlobObjectListPtr = - std::shared_ptr>>; - } // namespace one class DeviceCtx; diff --git a/oneflow/core/eager/critical_section_phy_instr_operand.h b/oneflow/core/eager/critical_section_phy_instr_operand.h index eac77d38c41..26207bfa31b 100644 --- a/oneflow/core/eager/critical_section_phy_instr_operand.h +++ b/oneflow/core/eager/critical_section_phy_instr_operand.h @@ -24,13 +24,6 @@ limitations under the License. namespace oneflow { -namespace one { - -using EagerBlobObjectListPtr = - std::shared_ptr>>; - -} - namespace vm { class Stream; diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index 22cc9aaf7dd..7980ef8fc2b 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -222,6 +222,16 @@ class EagerBlobObject final : public user_op::Tensor, }; } // namespace vm + +namespace one { + +constexpr static int kEagerBlobObjectListReservedSize = 4; +using EagerBlobObjectList = + small_vector, kEagerBlobObjectListReservedSize>; +using EagerBlobObjectListPtr = std::shared_ptr; + +} // namespace one + } // namespace oneflow #endif // ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_ diff --git a/oneflow/core/eager/lazy_job_phy_instr_operand.h b/oneflow/core/eager/lazy_job_phy_instr_operand.h index 2a231fdd0d7..e7308a61b8c 100644 --- a/oneflow/core/eager/lazy_job_phy_instr_operand.h +++ b/oneflow/core/eager/lazy_job_phy_instr_operand.h @@ -25,13 +25,6 @@ limitations under the License. 
namespace oneflow { -namespace one { - -using EagerBlobObjectListPtr = - std::shared_ptr>>; - -} - namespace vm { class LaunchLazyJobPhyInstrOperand final : public PhyInstrOperand { diff --git a/oneflow/core/framework/nn_graph.cpp b/oneflow/core/framework/nn_graph.cpp index e38ca274799..e20efe83659 100644 --- a/oneflow/core/framework/nn_graph.cpp +++ b/oneflow/core/framework/nn_graph.cpp @@ -446,7 +446,7 @@ Maybe NNGraph::GetVariableRealBlobAfterSyncPlan() { } // Initialize or check mem_ptr_for_allocation_computation_pipelining by TouchTensors instruction. JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - auto eager_blob_objects = std::make_shared>>(); + auto eager_blob_objects = std::make_shared(); for (const auto& pair : variable_op_name2eager_blob_object_) { eager_blob_objects->push_back(pair.second->shared_from_this()); } @@ -508,7 +508,7 @@ void NNGraph::CloseRuntimeBuffers() { namespace { -Maybe MakeEagerBlobObjectList(std::vector>* blob_list, +Maybe MakeEagerBlobObjectList(one::EagerBlobObjectList* blob_list, const one::TensorTuple& tensor_list) { blob_list->reserve(tensor_list.size()); for (const auto& tensor : tensor_list) { @@ -549,21 +549,18 @@ Maybe RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTupl CHECK_OR_RETURN(nn_graph->outputs_tensor_meta_str().at(i) == *JUST(GetTensorMetaString(outputs.at(i)))); } - std::vector> input_blobs; - std::vector> output_blobs; - std::vector> var_blobs; + one::EagerBlobObjectList input_blobs; + one::EagerBlobObjectList output_blobs; + one::EagerBlobObjectList var_blobs; JUST(MakeEagerBlobObjectList(&input_blobs, inputs)); JUST(MakeEagerBlobObjectList(&output_blobs, outputs)); JUST(MakeEagerBlobObjectList(&var_blobs, parameters)); const auto& input_blob_list_ptr = - std::make_shared>>( - std::move(input_blobs)); + std::make_shared(std::move(input_blobs)); const auto& output_blob_list_ptr = - std::make_shared>>( - std::move(output_blobs)); + std::make_shared(std::move(output_blobs)); const 
auto& var_blob_list_ptr = - std::make_shared>>( - std::move(var_blobs)); + std::make_shared(std::move(var_blobs)); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->LaunchLazyJob(input_blob_list_ptr, output_blob_list_ptr, var_blob_list_ptr, nn_graph); @@ -573,8 +570,7 @@ Maybe RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTupl Maybe SoftSyncNNGraphBuffers(const one::TensorTuple& buffers, const std::shared_ptr& nn_graph) { - const auto& eager_blob_objects = - std::make_shared>>(); + const auto& eager_blob_objects = std::make_shared(); JUST(MakeEagerBlobObjectList(eager_blob_objects.get(), buffers)); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SoftSyncNNGraphBuffers(eager_blob_objects, nn_graph); diff --git a/oneflow/core/framework/op_expr_grad_function.h b/oneflow/core/framework/op_expr_grad_function.h index 02dacf23ebc..969822acf7f 100644 --- a/oneflow/core/framework/op_expr_grad_function.h +++ b/oneflow/core/framework/op_expr_grad_function.h @@ -20,6 +20,7 @@ limitations under the License. #include "oneflow/core/autograd/autograd_captured_tensor.h" #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/framework/op_interpreter.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { @@ -96,14 +97,19 @@ class OpExprGradFunction : public OpExprGradFunctionIf { CHECK_NOTNULL_OR_RETURN(state); // Convert outputs from `Tensor` to `AutogradCapturedTensor` to avoid // circular reference between `Tensor` and `FunctionNode`. 
+ OF_PROFILER_RANGE_PUSH("init inputs"); TensorTuple captured_inputs(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { captured_inputs[i] = JUST(AutogradCapturedTensor::MakeTensor(inputs.at(i))); } + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("init outputs"); TensorTuple captured_outputs(outputs.size()); for (int i = 0; i < outputs.size(); ++i) { captured_outputs[i] = JUST(AutogradCapturedTensor::MakeTensor(outputs.at(i))); } + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_GUARD("Capture"); return Capture(state, captured_inputs, captured_outputs, interp_ctx); } diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp index 357d563acaa..b5b08c8d833 100644 --- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp @@ -39,6 +39,7 @@ limitations under the License. #include "oneflow/core/framework/id_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/rpc/include/global_process_ctx.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { @@ -86,6 +87,8 @@ std::vector* ThreadLocalDefaultOutputMutTensorMetas(int64_t size) { Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, const Symbol& default_device, TensorTuple* outputs, const OpExprInterpContext& ctx) { + OF_PROFILER_RANGE_GUARD("NaiveInterpret"); + OF_PROFILER_RANGE_PUSH("init inputs"); const auto& attrs = ctx.attrs; std::shared_ptr input_eager_blob_objects = std::make_shared(inputs.size()); @@ -100,6 +103,8 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } input_eager_blob_objects->at(i) = JUST(inputs.at(i)->eager_blob_object()); } + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("init outputs"); std::shared_ptr output_eager_blob_objects = std::make_shared(outputs->size()); auto* 
output_tensor_metas = ThreadLocalDefaultOutputMutTensorMetas(outputs->size()); @@ -117,6 +122,8 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in Symbol stream; bool need_check_mem_case = true; + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("infer devices"); // Infer devices if (!user_op_expr.has_device_and_stream_infer_fn()) { stream = JUST(GetDefaultStreamByDevice(default_device)); @@ -129,6 +136,8 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs)); } + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("infer shapes and dtypes"); // Infer shapes and dtypes const auto& device_tag = stream->device()->type(); JUST(user_op_expr.InferPhysicalTensorDesc( @@ -142,6 +151,8 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in return output_tensor_metas->at(i); })); + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("init output eager_blob_objects"); for (int i = 0; i < output_eager_blob_objects->size(); i++) { auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); if (!output_eager_blob_objects->at(i)) { @@ -166,16 +177,20 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } } + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("init opkernel"); const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream)); kernel->set_need_check_mem_case(need_check_mem_case); for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { output_eager_blob_objects->at(index)->set_is_shape_synced(false); } - + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("PhysicalRun"); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream); })); + OF_PROFILER_RANGE_POP(); return Maybe::Ok(); } diff --git a/oneflow/core/framework/op_interpreter/op_interpreter.cpp 
b/oneflow/core/framework/op_interpreter/op_interpreter.cpp index 6dea92f954c..6b669bc3669 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { @@ -112,6 +113,7 @@ Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& // Lazy mode will construct backward compute graph in passes, so disable autograd if lazy mode. std::shared_ptr grad_closure(nullptr); if (requires_grad && !LazyMode::is_enabled()) { + OF_PROFILER_RANGE_PUSH("autograd.GetOrCreateOpGradClosure"); grad_closure = JUST(op_expr.GetOrCreateOpGradClosure()); auto backward_fn = std::make_shared(); backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads, @@ -121,8 +123,11 @@ Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& return Maybe::Ok(); }; backward_fn->status = [=]() { return grad_closure->state()->SavedTensors().size() > 0; }; + OF_PROFILER_RANGE_POP(); + OF_PROFILER_RANGE_PUSH("autograd.AddNode"); JUST(GetThreadLocalAutogradEngine()->AddNode(op_expr.op_type_name() + "_backward", backward_fn, *inputs_ptr, outputs)); + OF_PROFILER_RANGE_POP(); } // Update outputs autograd meta // Note: if requires_grad is True, we will create a new autograd meta for each output @@ -157,6 +162,7 @@ Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& } if (requires_grad && !LazyMode::is_enabled()) { + OF_PROFILER_RANGE_GUARD("autograd.Capture"); // Capture inputs and outputs after `AddBackwardFuncPtr` because of that grad function // node has been attached to them. 
JUST(grad_closure->Capture(*inputs_ptr, *outputs, ctx)); diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp index f9eff347004..e7533372a7c 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp @@ -25,6 +25,7 @@ limitations under the License. #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { @@ -125,6 +126,7 @@ Maybe GetInterpreter(const TensorTuple& inputs, const OpExp template<> /* static */ Maybe OpInterpUtil::Dispatch( const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) { + OF_PROFILER_RANGE_GUARD("Dispatch"); auto outputs = std::make_shared(op_expr.output_size()); JUST(Dispatch(op_expr, inputs, outputs.get(), ctx)); return outputs; @@ -134,12 +136,14 @@ template<> /* static */ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) { + OF_PROFILER_RANGE_GUARD("Dispatch"); return JUST(Dispatch(op_expr, inputs, ctx))->at(0); } /* static */ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { + OF_PROFILER_RANGE_GUARD("Dispatch"); return JUST(GetInterpreter(inputs, ctx, op_expr))->Apply(op_expr, inputs, outputs, ctx); } diff --git a/oneflow/core/framework/tensor_tuple.h b/oneflow/core/framework/tensor_tuple.h index b996aa5b080..39f3f62d48b 100644 --- a/oneflow/core/framework/tensor_tuple.h +++ b/oneflow/core/framework/tensor_tuple.h @@ -19,13 +19,16 @@ limitations under the License. 
#include #include +#include "oneflow/core/common/small_vector.h" namespace oneflow { namespace one { class Tensor; -class TensorTuple final : public std::vector>, +constexpr static int kTensorTupleReservedSize = 4; + +class TensorTuple final : public small_vector, kTensorTupleReservedSize>, public std::enable_shared_from_this { public: // TensorTuple(const TensorTuple&) = delete; diff --git a/oneflow/core/vm/touch_tensors_instruction_type.cpp b/oneflow/core/vm/touch_tensors_instruction_type.cpp index 5004ddb0ed6..d59b605ac61 100644 --- a/oneflow/core/vm/touch_tensors_instruction_type.cpp +++ b/oneflow/core/vm/touch_tensors_instruction_type.cpp @@ -20,7 +20,7 @@ namespace oneflow { namespace vm { TouchTensorsPhyInstrOperand::TouchTensorsPhyInstrOperand( - const std::vector>& eager_blob_objects) + const one::EagerBlobObjectList& eager_blob_objects) : eager_blob_objects_(eager_blob_objects) { const auto& Insert = SetInserter(&input_dependences_); for (const auto& eager_blob_object : eager_blob_objects_) { diff --git a/oneflow/core/vm/touch_tensors_instruction_type.h b/oneflow/core/vm/touch_tensors_instruction_type.h index 9b259865688..e2ada6ab594 100644 --- a/oneflow/core/vm/touch_tensors_instruction_type.h +++ b/oneflow/core/vm/touch_tensors_instruction_type.h @@ -18,17 +18,16 @@ limitations under the License. 
#include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/vm/phy_instr_operand.h" +#include "oneflow/core/eager/eager_blob_object.h" namespace oneflow { namespace vm { -class EagerBlobObject; class Instruction; class TouchTensorsPhyInstrOperand final : public PhyInstrOperand { public: - TouchTensorsPhyInstrOperand( - const std::vector>& eager_blob_objects); + TouchTensorsPhyInstrOperand(const one::EagerBlobObjectList& eager_blob_objects); const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { @@ -41,7 +40,7 @@ class TouchTensorsPhyInstrOperand final : public PhyInstrOperand { } private: - std::vector> eager_blob_objects_; + one::EagerBlobObjectList eager_blob_objects_; DependenceVector input_dependences_; }; diff --git a/oneflow/core/vm/virtual_machine_engine.cpp b/oneflow/core/vm/virtual_machine_engine.cpp index 117bb2022ac..91554c8b1a2 100644 --- a/oneflow/core/vm/virtual_machine_engine.cpp +++ b/oneflow/core/vm/virtual_machine_engine.cpp @@ -346,12 +346,12 @@ void VirtualMachineEngine::DispatchInstruction(Instruction* instruction, // Returns true if old scheduler_pending_instruction_list is empty Maybe VirtualMachineEngine::Receive(InstructionList* compute_instruction_list) { - OF_PROFILER_RANGE_GUARD("vm:Receive"); -#ifdef OF_ENABLE_PROFILER - INTRUSIVE_UNSAFE_FOR_EACH_PTR(compute_instruction, compute_instruction_list) { - OF_PROFILER_RANGE_GUARD(compute_instruction->DebugName()); - } -#endif + // OF_PROFILER_RANGE_GUARD("vm:Receive"); + // #ifdef OF_ENABLE_PROFILER + // INTRUSIVE_UNSAFE_FOR_EACH_PTR(compute_instruction, compute_instruction_list) { + // OF_PROFILER_RANGE_GUARD(compute_instruction->DebugName()); + // } + // #endif bool old_list_empty = mut_pending_instruction_list()->MoveFrom(compute_instruction_list); return old_list_empty; } diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 
7fbf2eced47..2c6f4d7faeb 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -785,7 +785,7 @@ size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx, Maybe StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx, const user_op::OpKernel** user_opkernel, bool* need_temp_storage) { - OF_PROFILER_RANGE_GUARD("ChooseOpKernel"); + // OF_PROFILER_RANGE_GUARD("ChooseOpKernel"); DataType primary_dtype = kInvalidDataType; const auto& inputs = call_ctx->inputs(); const auto& outputs = call_ctx->outputs(); diff --git a/tools/functional/generator.py b/tools/functional/generator.py index 6a0054a655a..792cd6651d2 100644 --- a/tools/functional/generator.py +++ b/tools/functional/generator.py @@ -535,6 +535,7 @@ def generate_pybind_for_python( name ) schema_fmt += " HANDLE_ERRORS\n" + schema_fmt += ' OF_PROFILER_RANGE_GUARD("{0}");\n'.format(name) schema_fmt += " PythonFrameGuard pf;\n" schema_fmt += ' static PythonArgParser<{0}> parser("{1}");\n'.format( ", ".join(schema_types), name From 720262ff29eb3952d0b54f6a275aafc4bcb18a31 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Fri, 8 Jul 2022 11:33:32 +0800 Subject: [PATCH 03/67] op_args_reserved_size --- oneflow/core/common/op_args_reserved_size.h | 25 +++++++++++++++++++++ oneflow/core/eager/eager_blob_object.h | 5 ++--- oneflow/core/framework/tensor_tuple.h | 5 ++--- 3 files changed, 29 insertions(+), 6 deletions(-) create mode 100644 oneflow/core/common/op_args_reserved_size.h diff --git a/oneflow/core/common/op_args_reserved_size.h b/oneflow/core/common/op_args_reserved_size.h new file mode 100644 index 00000000000..83c97e03b82 --- /dev/null +++ b/oneflow/core/common/op_args_reserved_size.h @@ -0,0 +1,25 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_ +#define ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_ + +namespace oneflow { + +constexpr static int kOpArgsReservedSize = 4; + +} + +#endif // ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_ diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index 7980ef8fc2b..9fab1632bf9 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/optional.h" +#include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/device/device_context.h" #include "oneflow/core/memory/memory_allocator.h" @@ -225,9 +226,7 @@ class EagerBlobObject final : public user_op::Tensor, namespace one { -constexpr static int kEagerBlobObjectListReservedSize = 4; -using EagerBlobObjectList = - small_vector, kEagerBlobObjectListReservedSize>; +using EagerBlobObjectList = small_vector, kOpArgsReservedSize>; using EagerBlobObjectListPtr = std::shared_ptr; } // namespace one diff --git a/oneflow/core/framework/tensor_tuple.h b/oneflow/core/framework/tensor_tuple.h index 39f3f62d48b..51b8c947f8f 100644 --- a/oneflow/core/framework/tensor_tuple.h +++ b/oneflow/core/framework/tensor_tuple.h @@ -20,15 +20,14 @@ limitations under the License. 
#include #include #include "oneflow/core/common/small_vector.h" +#include "oneflow/core/common/op_args_reserved_size.h" namespace oneflow { namespace one { class Tensor; -constexpr static int kTensorTupleReservedSize = 4; - -class TensorTuple final : public small_vector, kTensorTupleReservedSize>, +class TensorTuple final : public small_vector, kOpArgsReservedSize>, public std::enable_shared_from_this { public: // TensorTuple(const TensorTuple&) = delete; From 75ba7b6d3ea1600163b681c9c1cdff666b705bea Mon Sep 17 00:00:00 2001 From: lixinqi Date: Fri, 8 Jul 2022 12:19:22 +0800 Subject: [PATCH 04/67] remove useless comments --- oneflow/core/vm/virtual_machine_engine.cpp | 6 ------ oneflow/user/kernels/stateful_opkernel.cpp | 1 - 2 files changed, 7 deletions(-) diff --git a/oneflow/core/vm/virtual_machine_engine.cpp b/oneflow/core/vm/virtual_machine_engine.cpp index 1383a3abdf8..adc15be5f65 100644 --- a/oneflow/core/vm/virtual_machine_engine.cpp +++ b/oneflow/core/vm/virtual_machine_engine.cpp @@ -344,12 +344,6 @@ void VirtualMachineEngine::DispatchInstruction(Instruction* instruction, // Returns true if old scheduler_pending_instruction_list is empty Maybe VirtualMachineEngine::Receive(InstructionList* compute_instruction_list) { - // OF_PROFILER_RANGE_GUARD("vm:Receive"); - // #ifdef OF_ENABLE_PROFILER - // INTRUSIVE_UNSAFE_FOR_EACH_PTR(compute_instruction, compute_instruction_list) { - // OF_PROFILER_RANGE_GUARD(compute_instruction->DebugName()); - // } - // #endif bool old_list_empty = mut_pending_instruction_list()->MoveFrom(compute_instruction_list); return old_list_empty; } diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 2c6f4d7faeb..4d4cb15f916 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -785,7 +785,6 @@ size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx, Maybe StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx, 
const user_op::OpKernel** user_opkernel, bool* need_temp_storage) { - // OF_PROFILER_RANGE_GUARD("ChooseOpKernel"); DataType primary_dtype = kInvalidDataType; const auto& inputs = call_ctx->inputs(); const auto& outputs = call_ctx->outputs(); From 9a40b2345ac022d84271976fb233ffa3aaf162bc Mon Sep 17 00:00:00 2001 From: lixinqi Date: Sat, 9 Jul 2022 18:06:56 +0800 Subject: [PATCH 05/67] rename one::EagerBlobObjectList* to vm::EagerBlobObject* --- oneflow/core/eager/call_context.h | 12 +++++------ .../critical_section_phy_instr_operand.h | 10 ++++----- oneflow/core/eager/eager_blob_object.h | 6 +----- .../core/eager/lazy_job_phy_instr_operand.h | 4 ++-- .../core/eager/op_call_phy_instr_operand.cpp | 2 +- .../core/eager/op_call_phy_instr_operand.h | 6 +++--- .../core/framework/instructions_builder.cpp | 21 +++++++++---------- oneflow/core/framework/instructions_builder.h | 20 +++++++++--------- oneflow/core/framework/nn_graph.cpp | 18 ++++++++-------- .../eager_consistent_op_interpreter.cpp | 8 +++---- .../eager_local_op_interpreter.cpp | 10 ++++----- .../vm/touch_tensors_instruction_type.cpp | 2 +- .../core/vm/touch_tensors_instruction_type.h | 4 ++-- 13 files changed, 59 insertions(+), 64 deletions(-) diff --git a/oneflow/core/eager/call_context.h b/oneflow/core/eager/call_context.h index ac47f54045f..b709771195e 100644 --- a/oneflow/core/eager/call_context.h +++ b/oneflow/core/eager/call_context.h @@ -71,8 +71,8 @@ class TmpTensor final : public user_op::Tensor { class CallContext { public: CallContext( - ComposedAttrMap&& composed_attrs, const one::EagerBlobObjectListPtr& inputs, - const one::EagerBlobObjectListPtr& outputs, + ComposedAttrMap&& composed_attrs, const vm::EagerBlobObjectListPtr& inputs, + const vm::EagerBlobObjectListPtr& outputs, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const std::shared_ptr& mem_case) : composed_attrs_(std::move(composed_attrs)), @@ -85,8 +85,8 @@ class CallContext { 
~CallContext() = default; const ComposedAttrMap& composed_attrs() const { return composed_attrs_; } - const one::EagerBlobObjectListPtr& inputs() const { return inputs_; } - const one::EagerBlobObjectListPtr& outputs() const { return outputs_; } + const vm::EagerBlobObjectListPtr& inputs() const { return inputs_; } + const vm::EagerBlobObjectListPtr& outputs() const { return outputs_; } const std::shared_ptr& consistent_tensor_infer_result() const { return consistent_tensor_infer_result_; @@ -96,8 +96,8 @@ class CallContext { private: const ComposedAttrMap composed_attrs_; - const one::EagerBlobObjectListPtr inputs_; - const one::EagerBlobObjectListPtr outputs_; + const vm::EagerBlobObjectListPtr inputs_; + const vm::EagerBlobObjectListPtr outputs_; const std::shared_ptr consistent_tensor_infer_result_; const one::OpExprInterpContext op_interp_ctx_; TmpTensor tmp_tensor_; diff --git a/oneflow/core/eager/critical_section_phy_instr_operand.h b/oneflow/core/eager/critical_section_phy_instr_operand.h index d0ec63397d5..93480eaa78d 100644 --- a/oneflow/core/eager/critical_section_phy_instr_operand.h +++ b/oneflow/core/eager/critical_section_phy_instr_operand.h @@ -39,7 +39,7 @@ class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand { explicit CriticalSectionBeginPhyInstrOperand( const std::shared_ptr& nn_graph, - const one::EagerBlobObjectListPtr& eager_blob_objects, + const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& op_name2end_event_record, vm::Stream* vm_stream) @@ -49,7 +49,7 @@ class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand { vm_stream_(vm_stream) {} const std::shared_ptr& nn_graph() const { return nn_graph_; } - const one::EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; } + const vm::EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; } void ForEachDependence(const std::function&) const; @@ -74,7 +74,7 @@ class 
CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand { protected: std::shared_ptr nn_graph_; - one::EagerBlobObjectListPtr eager_blob_objects_; + vm::EagerBlobObjectListPtr eager_blob_objects_; std::shared_ptr>> op_name2end_event_record_; HashMap op_name2interface_index_; @@ -85,7 +85,7 @@ class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeg public: InputCriticalSectionBeginPhyInstrOperand( const std::shared_ptr& nn_graph, - const one::EagerBlobObjectListPtr& eager_blob_objects, + const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& op_name2end_event_record, vm::Stream* vm_stream) @@ -142,7 +142,7 @@ class OutputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBe public: OutputCriticalSectionBeginPhyInstrOperand( const std::shared_ptr& nn_graph, - const one::EagerBlobObjectListPtr& eager_blob_objects, + const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& op_name2end_event_record, vm::Stream* vm_stream) diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index 9fab1632bf9..45e6569d3d4 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -222,14 +222,10 @@ class EagerBlobObject final : public user_op::Tensor, std::unique_ptr blob_; }; -} // namespace vm - -namespace one { - using EagerBlobObjectList = small_vector, kOpArgsReservedSize>; using EagerBlobObjectListPtr = std::shared_ptr; -} // namespace one +} // namespace vm } // namespace oneflow diff --git a/oneflow/core/eager/lazy_job_phy_instr_operand.h b/oneflow/core/eager/lazy_job_phy_instr_operand.h index 7652c2b6166..809dbfc71e7 100644 --- a/oneflow/core/eager/lazy_job_phy_instr_operand.h +++ b/oneflow/core/eager/lazy_job_phy_instr_operand.h @@ -34,7 +34,7 @@ class LaunchLazyJobPhyInstrOperand final : public PhyInstrOperand { ~LaunchLazyJobPhyInstrOperand() override = default; LaunchLazyJobPhyInstrOperand(const 
std::shared_ptr& nn_graph, - const one::EagerBlobObjectListPtr& param_blob_objects) + const vm::EagerBlobObjectListPtr& param_blob_objects) : nn_graph_(nn_graph), param_blob_objects_(param_blob_objects), input_dependences_(), @@ -62,7 +62,7 @@ class LaunchLazyJobPhyInstrOperand final : public PhyInstrOperand { private: std::shared_ptr nn_graph_; - one::EagerBlobObjectListPtr param_blob_objects_; + vm::EagerBlobObjectListPtr param_blob_objects_; DependenceVector input_dependences_; DependenceVector output_dependences_; }; diff --git a/oneflow/core/eager/op_call_phy_instr_operand.cpp b/oneflow/core/eager/op_call_phy_instr_operand.cpp index 4ad32b8752d..0a334b55dad 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.cpp +++ b/oneflow/core/eager/op_call_phy_instr_operand.cpp @@ -24,7 +24,7 @@ namespace vm { OpCallPhyInstrOperand::OpCallPhyInstrOperand( vm::Stream* vm_stream, const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs, + const vm::EagerBlobObjectListPtr& inputs, const vm::EagerBlobObjectListPtr& outputs, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode) diff --git a/oneflow/core/eager/op_call_phy_instr_operand.h b/oneflow/core/eager/op_call_phy_instr_operand.h index 5c3940adac2..0557ae57e7a 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.h +++ b/oneflow/core/eager/op_call_phy_instr_operand.h @@ -49,8 +49,8 @@ class OpCallPhyInstrOperand final : public vm::PhyInstrOperand { } const one::StatefulOpKernel& opkernel() const { return *opkernel_; } - const one::EagerBlobObjectListPtr& inputs() const { return call_ctx_.inputs(); } - const one::EagerBlobObjectListPtr& outputs() const { return call_ctx_.outputs(); } + const vm::EagerBlobObjectListPtr& inputs() const { return call_ctx_.inputs(); } + const vm::EagerBlobObjectListPtr& outputs() const { return 
call_ctx_.outputs(); } const AttrMap& attrs() const { return call_ctx_.op_interp_ctx().attrs; } const one::OpExprInterpContext& op_interp_ctx() const { return call_ctx_.op_interp_ctx(); } const one::DevVmDepObjectConsumeMode& dev_vm_dep_object_consume_mode() const { @@ -93,7 +93,7 @@ class OpCallPhyInstrOperand final : public vm::PhyInstrOperand { friend struct OpCallInstructionUtil; OpCallPhyInstrOperand( vm::Stream* vm_stream, const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs, + const vm::EagerBlobObjectListPtr& inputs, const vm::EagerBlobObjectListPtr& outputs, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode); diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index a9aba9ecfda..4beb8e095a6 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -130,9 +130,9 @@ Maybe InstructionsBuilder::MakeCriticalSectionEnd( // CriticalSectionBegin. // critical_section_callback is a non-blocking opkernel which notifies instruction // CriticalSectionEnd done. 
-Maybe InstructionsBuilder::LaunchLazyJob(const one::EagerBlobObjectListPtr& inputs, - const one::EagerBlobObjectListPtr& outputs, - const one::EagerBlobObjectListPtr& parameters, +Maybe InstructionsBuilder::LaunchLazyJob(const vm::EagerBlobObjectListPtr& inputs, + const vm::EagerBlobObjectListPtr& outputs, + const vm::EagerBlobObjectListPtr& parameters, const std::shared_ptr& nn_graph) { JUST(SoftSyncNNGraphBuffers(inputs, nn_graph)); JUST(SoftSyncNNGraphBuffers(outputs, nn_graph)); @@ -202,7 +202,7 @@ Maybe InstructionsBuilder::LaunchLazyJob(const one::EagerBlobObjectListPtr } Maybe InstructionsBuilder::SoftSyncNNGraphBuffers( - const one::EagerBlobObjectListPtr& eager_blob_objects, + const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr& nn_graph) { const auto& stream = JUST(GetCriticalSectionStream()); JUST(SoftSyncStream(eager_blob_objects, stream)); @@ -359,16 +359,16 @@ Maybe InstructionsBuilder::BuildScopeByProtoStrSetter( } Maybe InstructionsBuilder::Call(const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& input_eager_blob_objects, - const one::EagerBlobObjectListPtr& output_eager_blob_objects, + const vm::EagerBlobObjectListPtr& input_eager_blob_objects, + const vm::EagerBlobObjectListPtr& output_eager_blob_objects, const one::OpExprInterpContext& ctx, Symbol stream) { return Call(opkernel, input_eager_blob_objects, output_eager_blob_objects, nullptr, ctx, stream); } Maybe InstructionsBuilder::Call( const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& input_eager_blob_objects, - const one::EagerBlobObjectListPtr& output_eager_blob_objects, + const vm::EagerBlobObjectListPtr& input_eager_blob_objects, + const vm::EagerBlobObjectListPtr& output_eager_blob_objects, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& ctx, Symbol stream) { JUST(SoftSyncStream(output_eager_blob_objects, stream)); @@ -426,8 +426,7 @@ Maybe InstructionsBuilder::ReleaseTensor( return 
Maybe::Ok(); } -Maybe InstructionsBuilder::TouchTensors( - const one::EagerBlobObjectListPtr& eager_blob_object) { +Maybe InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_object) { const auto& phy_instr_operand = std::make_shared(*eager_blob_object); Symbol device = JUST(Device::New("cpu")); @@ -440,7 +439,7 @@ Maybe InstructionsBuilder::TouchTensors( } Maybe InstructionsBuilder::SoftSyncStream( - const one::EagerBlobObjectListPtr& eager_blob_objects, Symbol stream) { + const vm::EagerBlobObjectListPtr& eager_blob_objects, Symbol stream) { SmallSet> last_used_streams; for (const auto& eager_blob_object : *eager_blob_objects) { const auto& opt_last_used_stream = eager_blob_object->last_used_stream(); diff --git a/oneflow/core/framework/instructions_builder.h b/oneflow/core/framework/instructions_builder.h index 4f68dbcf840..e832b0e7de2 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -55,13 +55,13 @@ class InstructionsBuilder : public std::enable_shared_from_this LaunchLazyJob(const one::EagerBlobObjectListPtr& inputs, - const one::EagerBlobObjectListPtr& outputs, - const one::EagerBlobObjectListPtr& parameters, + Maybe LaunchLazyJob(const vm::EagerBlobObjectListPtr& inputs, + const vm::EagerBlobObjectListPtr& outputs, + const vm::EagerBlobObjectListPtr& parameters, const std::shared_ptr& nn_graph); // soft sync for inputs/outputs buffers of NNGraph - Maybe SoftSyncNNGraphBuffers(const one::EagerBlobObjectListPtr& eager_blob_objects, + Maybe SoftSyncNNGraphBuffers(const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr& nn_graph); Maybe GetJobConfSymbol(const JobConfigProto& job_conf); @@ -74,7 +74,7 @@ class InstructionsBuilder : public std::enable_shared_from_this ReleaseTensor(const std::shared_ptr& eager_blob_object); - Maybe TouchTensors(const one::EagerBlobObjectListPtr& eager_blob_object); + Maybe TouchTensors(const vm::EagerBlobObjectListPtr& 
eager_blob_object); template Maybe SyncAccessBlobByCallback(const T tensor, const std::shared_ptr& btb, @@ -118,19 +118,19 @@ class InstructionsBuilder : public std::enable_shared_from_this& StrSetter); Maybe Call(const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& input_eager_blob_objects, - const one::EagerBlobObjectListPtr& output_eager_blob_objects, + const vm::EagerBlobObjectListPtr& input_eager_blob_objects, + const vm::EagerBlobObjectListPtr& output_eager_blob_objects, const one::OpExprInterpContext& ctx, Symbol stream); Maybe Call( const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& input_eager_blob_objects, - const one::EagerBlobObjectListPtr& output_eager_blob_objects, + const vm::EagerBlobObjectListPtr& input_eager_blob_objects, + const vm::EagerBlobObjectListPtr& output_eager_blob_objects, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& ctx, Symbol stream); private: - Maybe SoftSyncStream(const one::EagerBlobObjectListPtr& eager_blob_objects, + Maybe SoftSyncStream(const vm::EagerBlobObjectListPtr& eager_blob_objects, Symbol stream); Maybe SoftSyncStream( std::vector>&& compute_local_dep_objects, diff --git a/oneflow/core/framework/nn_graph.cpp b/oneflow/core/framework/nn_graph.cpp index be2d7ff5c98..15abd046260 100644 --- a/oneflow/core/framework/nn_graph.cpp +++ b/oneflow/core/framework/nn_graph.cpp @@ -446,7 +446,7 @@ Maybe NNGraph::GetVariableRealBlobAfterSyncPlan() { } // Initialize or check mem_ptr_for_allocation_computation_pipelining by TouchTensors instruction. 
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - auto eager_blob_objects = std::make_shared(); + auto eager_blob_objects = std::make_shared(); for (const auto& pair : variable_op_name2eager_blob_object_) { eager_blob_objects->push_back(pair.second->shared_from_this()); } @@ -508,7 +508,7 @@ void NNGraph::CloseRuntimeBuffers() { namespace { -Maybe MakeEagerBlobObjectList(one::EagerBlobObjectList* blob_list, +Maybe MakeEagerBlobObjectList(vm::EagerBlobObjectList* blob_list, const one::TensorTuple& tensor_list) { blob_list->reserve(tensor_list.size()); for (const auto& tensor : tensor_list) { @@ -549,18 +549,18 @@ Maybe RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTupl CHECK_OR_RETURN(nn_graph->outputs_tensor_meta_str().at(i) == *JUST(GetTensorMetaString(outputs.at(i)))); } - one::EagerBlobObjectList input_blobs; - one::EagerBlobObjectList output_blobs; - one::EagerBlobObjectList var_blobs; + vm::EagerBlobObjectList input_blobs; + vm::EagerBlobObjectList output_blobs; + vm::EagerBlobObjectList var_blobs; JUST(MakeEagerBlobObjectList(&input_blobs, inputs)); JUST(MakeEagerBlobObjectList(&output_blobs, outputs)); JUST(MakeEagerBlobObjectList(&var_blobs, parameters)); const auto& input_blob_list_ptr = - std::make_shared(std::move(input_blobs)); + std::make_shared(std::move(input_blobs)); const auto& output_blob_list_ptr = - std::make_shared(std::move(output_blobs)); + std::make_shared(std::move(output_blobs)); const auto& var_blob_list_ptr = - std::make_shared(std::move(var_blobs)); + std::make_shared(std::move(var_blobs)); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->LaunchLazyJob(input_blob_list_ptr, output_blob_list_ptr, var_blob_list_ptr, nn_graph); @@ -570,7 +570,7 @@ Maybe RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTupl Maybe SoftSyncNNGraphBuffers(const one::TensorTuple& buffers, const std::shared_ptr& nn_graph) { - const auto& eager_blob_objects = std::make_shared(); + const 
auto& eager_blob_objects = std::make_shared(); JUST(MakeEagerBlobObjectList(eager_blob_objects.get(), buffers)); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->SoftSyncNNGraphBuffers(eager_blob_objects, nn_graph); diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index 42e3fc12462..10a1a5e38fd 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -152,8 +152,8 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, CHECK_EQ_OR_RETURN(kernel->output_tuple_indexes4mut2_obns().size(), 0) << Error::UnimplementedError() << GetDynamicOpConsistentFailedDebugString(user_op_expr, *kernel); - std::shared_ptr input_eager_blob_objects = - std::make_shared(inputs.size()); + std::shared_ptr input_eager_blob_objects = + std::make_shared(inputs.size()); // expand lifetime of boxing outputs to the end of this function TensorTuple boxing_outputs; for (int i = 0; i < inputs.size(); ++i) { @@ -172,8 +172,8 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, } // Do nothing if the `parallel_desc` doesn't cover current ProcessCtx. 
if (!parallel_id.has_value()) { return Maybe::Ok(); } - std::shared_ptr output_eager_blob_objects = - std::make_shared(outputs->size()); + std::shared_ptr output_eager_blob_objects = + std::make_shared(outputs->size()); for (int i = 0; i < outputs->size(); ++i) { const auto& local_tensor = JUST(outputs->at(i)->cur_rank_phy_tensor()); output_eager_blob_objects->at(i) = JUST(local_tensor->eager_blob_object()); diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index a50c9f00656..32acef50131 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -90,12 +90,12 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in OF_PROFILER_RANGE_GUARD("NaiveInterpret"); OF_PROFILER_RANGE_PUSH("init inputs"); const auto& attrs = ctx.attrs; - std::shared_ptr input_eager_blob_objects = - std::make_shared(inputs.size()); + std::shared_ptr input_eager_blob_objects = + std::make_shared(inputs.size()); for (int i = 0; i < inputs.size(); i++) { const auto& input_device = JUST(inputs.at(i)->device()); if (i > 0) { - CHECK_OR_RETURN(*default_device == *input_device) + CHECK_OR_RETURN(default_device == input_device) << Error::RuntimeError() << "Expected all tensors to be on the same device, but found at least two devices, " << default_device->ToString() << " (positional 0) and " << input_device->ToString() @@ -105,8 +105,8 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("init outputs"); - std::shared_ptr output_eager_blob_objects = - std::make_shared(outputs->size()); + std::shared_ptr output_eager_blob_objects = + std::make_shared(outputs->size()); auto* output_tensor_metas = ThreadLocalDefaultOutputMutTensorMetas(outputs->size()); for (int i = 0; i < outputs->size(); i++) { if 
(!outputs->at(i)) { diff --git a/oneflow/core/vm/touch_tensors_instruction_type.cpp b/oneflow/core/vm/touch_tensors_instruction_type.cpp index d59b605ac61..b395b5063f6 100644 --- a/oneflow/core/vm/touch_tensors_instruction_type.cpp +++ b/oneflow/core/vm/touch_tensors_instruction_type.cpp @@ -20,7 +20,7 @@ namespace oneflow { namespace vm { TouchTensorsPhyInstrOperand::TouchTensorsPhyInstrOperand( - const one::EagerBlobObjectList& eager_blob_objects) + const vm::EagerBlobObjectList& eager_blob_objects) : eager_blob_objects_(eager_blob_objects) { const auto& Insert = SetInserter(&input_dependences_); for (const auto& eager_blob_object : eager_blob_objects_) { diff --git a/oneflow/core/vm/touch_tensors_instruction_type.h b/oneflow/core/vm/touch_tensors_instruction_type.h index e2ada6ab594..0e4c1571ebb 100644 --- a/oneflow/core/vm/touch_tensors_instruction_type.h +++ b/oneflow/core/vm/touch_tensors_instruction_type.h @@ -27,7 +27,7 @@ class Instruction; class TouchTensorsPhyInstrOperand final : public PhyInstrOperand { public: - TouchTensorsPhyInstrOperand(const one::EagerBlobObjectList& eager_blob_objects); + TouchTensorsPhyInstrOperand(const vm::EagerBlobObjectList& eager_blob_objects); const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { @@ -40,7 +40,7 @@ class TouchTensorsPhyInstrOperand final : public PhyInstrOperand { } private: - one::EagerBlobObjectList eager_blob_objects_; + vm::EagerBlobObjectList eager_blob_objects_; DependenceVector input_dependences_; }; From bf87255f40836a4b630a14c2748a987a6e900b0c Mon Sep 17 00:00:00 2001 From: lixinqi Date: Sat, 9 Jul 2022 19:36:41 +0800 Subject: [PATCH 06/67] refactor signature of InstructionsBuiler::Call --- oneflow/core/eager/call_context.h | 16 ++++----- .../core/eager/op_call_instruction_type.cpp | 2 +- .../core/eager/op_call_phy_instr_operand.cpp | 15 ++++---- .../core/eager/op_call_phy_instr_operand.h | 10 +++--- 
.../core/framework/instructions_builder.cpp | 36 ++++++++++--------- oneflow/core/framework/instructions_builder.h | 10 +++--- .../eager_consistent_op_interpreter.cpp | 14 ++++---- .../eager_local_op_interpreter.cpp | 22 ++++++------ oneflow/user/kernels/stateful_opkernel.cpp | 12 +++---- 9 files changed, 70 insertions(+), 67 deletions(-) diff --git a/oneflow/core/eager/call_context.h b/oneflow/core/eager/call_context.h index b709771195e..bd0d2ab054f 100644 --- a/oneflow/core/eager/call_context.h +++ b/oneflow/core/eager/call_context.h @@ -71,13 +71,13 @@ class TmpTensor final : public user_op::Tensor { class CallContext { public: CallContext( - ComposedAttrMap&& composed_attrs, const vm::EagerBlobObjectListPtr& inputs, - const vm::EagerBlobObjectListPtr& outputs, + ComposedAttrMap&& composed_attrs, vm::EagerBlobObjectList&& inputs, + vm::EagerBlobObjectList&& outputs, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const std::shared_ptr& mem_case) : composed_attrs_(std::move(composed_attrs)), - inputs_(inputs), - outputs_(outputs), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), consistent_tensor_infer_result_(consistent_tensor_infer_result), op_interp_ctx_(op_interp_ctx), tmp_tensor_(mem_case) {} @@ -85,8 +85,8 @@ class CallContext { ~CallContext() = default; const ComposedAttrMap& composed_attrs() const { return composed_attrs_; } - const vm::EagerBlobObjectListPtr& inputs() const { return inputs_; } - const vm::EagerBlobObjectListPtr& outputs() const { return outputs_; } + const vm::EagerBlobObjectList& inputs() const { return inputs_; } + const vm::EagerBlobObjectList& outputs() const { return outputs_; } const std::shared_ptr& consistent_tensor_infer_result() const { return consistent_tensor_infer_result_; @@ -96,8 +96,8 @@ class CallContext { private: const ComposedAttrMap composed_attrs_; - const vm::EagerBlobObjectListPtr inputs_; - const vm::EagerBlobObjectListPtr outputs_; + const 
vm::EagerBlobObjectList inputs_; + const vm::EagerBlobObjectList outputs_; const std::shared_ptr consistent_tensor_infer_result_; const one::OpExprInterpContext op_interp_ctx_; TmpTensor tmp_tensor_; diff --git a/oneflow/core/eager/op_call_instruction_type.cpp b/oneflow/core/eager/op_call_instruction_type.cpp index f5a557be0dd..3a75a36e9df 100644 --- a/oneflow/core/eager/op_call_instruction_type.cpp +++ b/oneflow/core/eager/op_call_instruction_type.cpp @@ -99,7 +99,7 @@ struct OpCallInstructionUtil final { static inline Maybe AllocateOutputBlobsMemory(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx) { OF_PROFILER_RANGE_GUARD("AllocateOutputBlobsMemory"); - for (const auto& blob_object : *operand->outputs()) { + for (const auto& blob_object : operand->outputs()) { JUST(blob_object->TryAllocateBlobBodyMemory(device_ctx)); } return Maybe::Ok(); diff --git a/oneflow/core/eager/op_call_phy_instr_operand.cpp b/oneflow/core/eager/op_call_phy_instr_operand.cpp index 0a334b55dad..ba81bdca7a8 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.cpp +++ b/oneflow/core/eager/op_call_phy_instr_operand.cpp @@ -24,13 +24,14 @@ namespace vm { OpCallPhyInstrOperand::OpCallPhyInstrOperand( vm::Stream* vm_stream, const std::shared_ptr& opkernel, - const vm::EagerBlobObjectListPtr& inputs, const vm::EagerBlobObjectListPtr& outputs, + vm::EagerBlobObjectList&& inputs, vm::EagerBlobObjectList&& outputs, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode) : vm_stream_(vm_stream), - call_ctx_(ComposedAttrMap(op_interp_ctx.attrs, opkernel->base_attrs()), inputs, outputs, - consistent_tensor_infer_result, op_interp_ctx, opkernel->mem_case()), + call_ctx_(ComposedAttrMap(op_interp_ctx.attrs, opkernel->base_attrs()), std::move(inputs), + std::move(outputs), consistent_tensor_infer_result, op_interp_ctx, + opkernel->mem_case()), opkernel_(opkernel), 
user_opkernel_(nullptr), infer_tmp_size_fn_(nullptr), @@ -52,7 +53,7 @@ void OpCallPhyInstrOperand::ForEachConstDependence( const std::function& DoEach) const { const auto& input_list = inputs(); for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) { - const auto& input = input_list->at(index); + const auto& input = input_list.at(index); DoEach(CHECK_JUST(input->compute_local_dep_object())); } } @@ -80,12 +81,12 @@ void OpCallPhyInstrOperand::ForEachMutDependence( const auto& input_list = inputs(); for (int64_t index : opkernel().input_tuple_indexes4mut_ibns()) { - const auto& input = input_list->at(index); + const auto& input = input_list.at(index); DoEach(CHECK_JUST(input->compute_local_dep_object())); } const auto& output_list = outputs(); for (int64_t index : opkernel().output_tuple_indexes4mut_obns()) { - const auto& output = output_list->at(index); + const auto& output = output_list.at(index); DoEach(CHECK_JUST(output->compute_local_dep_object())); } } @@ -94,7 +95,7 @@ void OpCallPhyInstrOperand::ForEachMut2Dependence( const std::function& DoEach) const { const auto& output_list = outputs(); for (int64_t index : opkernel().output_tuple_indexes4mut2_obns()) { - const auto& output = output_list->at(index); + const auto& output = output_list.at(index); DoEach(CHECK_JUST(output->compute_local_dep_object())); } } diff --git a/oneflow/core/eager/op_call_phy_instr_operand.h b/oneflow/core/eager/op_call_phy_instr_operand.h index 0557ae57e7a..b8f1eb3b075 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.h +++ b/oneflow/core/eager/op_call_phy_instr_operand.h @@ -49,8 +49,8 @@ class OpCallPhyInstrOperand final : public vm::PhyInstrOperand { } const one::StatefulOpKernel& opkernel() const { return *opkernel_; } - const vm::EagerBlobObjectListPtr& inputs() const { return call_ctx_.inputs(); } - const vm::EagerBlobObjectListPtr& outputs() const { return call_ctx_.outputs(); } + const vm::EagerBlobObjectList& inputs() const { return call_ctx_.inputs(); 
} + const vm::EagerBlobObjectList& outputs() const { return call_ctx_.outputs(); } const AttrMap& attrs() const { return call_ctx_.op_interp_ctx().attrs; } const one::OpExprInterpContext& op_interp_ctx() const { return call_ctx_.op_interp_ctx(); } const one::DevVmDepObjectConsumeMode& dev_vm_dep_object_consume_mode() const { @@ -61,7 +61,7 @@ class OpCallPhyInstrOperand final : public vm::PhyInstrOperand { template Maybe ForEachOutputTensor(const DoEachT& DoEach) { - for (const auto& output : *outputs()) { JUST(DoEach(output.get())); } + for (const auto& output : outputs()) { JUST(DoEach(output.get())); } return Maybe::Ok(); } @@ -86,14 +86,14 @@ class OpCallPhyInstrOperand final : public vm::PhyInstrOperand { eager::CallContext* mut_call_ctx() { return &call_ctx_; } void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override { - for (const auto& eager_blob_object : *call_ctx_.inputs()) { DoEach(eager_blob_object.get()); } + for (const auto& eager_blob_object : call_ctx_.inputs()) { DoEach(eager_blob_object.get()); } } private: friend struct OpCallInstructionUtil; OpCallPhyInstrOperand( vm::Stream* vm_stream, const std::shared_ptr& opkernel, - const vm::EagerBlobObjectListPtr& inputs, const vm::EagerBlobObjectListPtr& outputs, + vm::EagerBlobObjectList&& inputs, vm::EagerBlobObjectList&& outputs, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx, const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode); diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 4beb8e095a6..ff73fdbd02b 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -205,7 +205,7 @@ Maybe InstructionsBuilder::SoftSyncNNGraphBuffers( const vm::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr& nn_graph) { const auto& stream = JUST(GetCriticalSectionStream()); - 
JUST(SoftSyncStream(eager_blob_objects, stream)); + JUST(SoftSyncStream(*eager_blob_objects, stream)); return Maybe::Ok(); } @@ -359,31 +359,33 @@ Maybe InstructionsBuilder::BuildScopeByProtoStrSetter( } Maybe InstructionsBuilder::Call(const std::shared_ptr& opkernel, - const vm::EagerBlobObjectListPtr& input_eager_blob_objects, - const vm::EagerBlobObjectListPtr& output_eager_blob_objects, + vm::EagerBlobObjectList&& input_eager_blob_objects, + vm::EagerBlobObjectList&& output_eager_blob_objects, const one::OpExprInterpContext& ctx, Symbol stream) { - return Call(opkernel, input_eager_blob_objects, output_eager_blob_objects, nullptr, ctx, stream); + return Call(opkernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects), + nullptr, ctx, stream); } Maybe InstructionsBuilder::Call( const std::shared_ptr& opkernel, - const vm::EagerBlobObjectListPtr& input_eager_blob_objects, - const vm::EagerBlobObjectListPtr& output_eager_blob_objects, + vm::EagerBlobObjectList&& input_eager_blob_objects, + vm::EagerBlobObjectList&& output_eager_blob_objects, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& ctx, Symbol stream) { JUST(SoftSyncStream(output_eager_blob_objects, stream)); JUST(SoftSyncStream(input_eager_blob_objects, stream)); + for (const auto& output : output_eager_blob_objects) { + if (!output->producer_stream().has_value()) { JUST(output->init_producer_stream(stream)); } + output->set_last_used_stream(stream); + } auto* vm_stream = JUST(Singleton::Get()->GetVmStream(stream)); auto phy_instr_operand = JUST(vm::OpCallPhyInstrOperand::New( - vm_stream, opkernel, input_eager_blob_objects, output_eager_blob_objects, - consistent_tensor_infer_result, ctx, *one::CurrentDevVmDepObjectConsumeMode())); + vm_stream, opkernel, std::move(input_eager_blob_objects), + std::move(output_eager_blob_objects), consistent_tensor_infer_result, ctx, + *one::CurrentDevVmDepObjectConsumeMode())); auto instruction = 
intrusive::make_shared( vm_stream, SingletonPtr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); - for (const auto& output : *output_eager_blob_objects) { - if (!output->producer_stream().has_value()) { JUST(output->init_producer_stream(stream)); } - output->set_last_used_stream(stream); - } return Maybe::Ok(); } @@ -438,10 +440,10 @@ Maybe InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& return Maybe::Ok(); } -Maybe InstructionsBuilder::SoftSyncStream( - const vm::EagerBlobObjectListPtr& eager_blob_objects, Symbol stream) { +Maybe InstructionsBuilder::SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects, + Symbol stream) { SmallSet> last_used_streams; - for (const auto& eager_blob_object : *eager_blob_objects) { + for (const auto& eager_blob_object : eager_blob_objects) { const auto& opt_last_used_stream = eager_blob_object->last_used_stream(); if (unlikely(!opt_last_used_stream.has_value())) { continue; } const auto& last_used_stream = JUST(opt_last_used_stream); @@ -449,8 +451,8 @@ Maybe InstructionsBuilder::SoftSyncStream( } for (const auto& last_used_stream : last_used_streams) { std::vector> dep_objects; - dep_objects.reserve(eager_blob_objects->size()); - for (const auto& eager_blob_object : *eager_blob_objects) { + dep_objects.reserve(eager_blob_objects.size()); + for (const auto& eager_blob_object : eager_blob_objects) { const auto& opt_last_used_stream = eager_blob_object->last_used_stream(); if (unlikely(!opt_last_used_stream.has_value())) { continue; } if (JUST(opt_last_used_stream) == last_used_stream) { diff --git a/oneflow/core/framework/instructions_builder.h b/oneflow/core/framework/instructions_builder.h index e832b0e7de2..85cb18cca12 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -118,19 +118,19 @@ class InstructionsBuilder : public std::enable_shared_from_this& StrSetter); Maybe Call(const std::shared_ptr& opkernel, - 
const vm::EagerBlobObjectListPtr& input_eager_blob_objects, - const vm::EagerBlobObjectListPtr& output_eager_blob_objects, + vm::EagerBlobObjectList&& input_eager_blob_objects, + vm::EagerBlobObjectList&& output_eager_blob_objects, const one::OpExprInterpContext& ctx, Symbol stream); Maybe Call( const std::shared_ptr& opkernel, - const vm::EagerBlobObjectListPtr& input_eager_blob_objects, - const vm::EagerBlobObjectListPtr& output_eager_blob_objects, + vm::EagerBlobObjectList&& input_eager_blob_objects, + vm::EagerBlobObjectList&& output_eager_blob_objects, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& ctx, Symbol stream); private: - Maybe SoftSyncStream(const vm::EagerBlobObjectListPtr& eager_blob_objects, + Maybe SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream); Maybe SoftSyncStream( std::vector>&& compute_local_dep_objects, diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index 10a1a5e38fd..c8a2bcd6e42 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -152,8 +152,7 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, CHECK_EQ_OR_RETURN(kernel->output_tuple_indexes4mut2_obns().size(), 0) << Error::UnimplementedError() << GetDynamicOpConsistentFailedDebugString(user_op_expr, *kernel); - std::shared_ptr input_eager_blob_objects = - std::make_shared(inputs.size()); + vm::EagerBlobObjectList input_eager_blob_objects(inputs.size()); // expand lifetime of boxing outputs to the end of this function TensorTuple boxing_outputs; for (int i = 0; i < inputs.size(); ++i) { @@ -168,18 +167,19 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, boxing_outputs.emplace_back(input); } const auto& local_tensor = 
JUST(input->cur_rank_phy_tensor()); - input_eager_blob_objects->at(i) = JUST(local_tensor->eager_blob_object()); + input_eager_blob_objects.at(i) = JUST(local_tensor->eager_blob_object()); } // Do nothing if the `parallel_desc` doesn't cover current ProcessCtx. if (!parallel_id.has_value()) { return Maybe::Ok(); } - std::shared_ptr output_eager_blob_objects = - std::make_shared(outputs->size()); + vm::EagerBlobObjectList output_eager_blob_objects(outputs->size()); for (int i = 0; i < outputs->size(); ++i) { const auto& local_tensor = JUST(outputs->at(i)->cur_rank_phy_tensor()); - output_eager_blob_objects->at(i) = JUST(local_tensor->eager_blob_object()); + output_eager_blob_objects.at(i) = JUST(local_tensor->eager_blob_object()); } + auto* inputs_ptr = &input_eager_blob_objects; + auto* outputs_ptr = &output_eager_blob_objects; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, result, ctx, + return builder->Call(kernel, std::move(*inputs_ptr), std::move(*outputs_ptr), result, ctx, result->stream()); })); return Maybe::Ok(); diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 32acef50131..94aa89bf1a7 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -90,8 +90,7 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in OF_PROFILER_RANGE_GUARD("NaiveInterpret"); OF_PROFILER_RANGE_PUSH("init inputs"); const auto& attrs = ctx.attrs; - std::shared_ptr input_eager_blob_objects = - std::make_shared(inputs.size()); + vm::EagerBlobObjectList input_eager_blob_objects(inputs.size()); for (int i = 0; i < inputs.size(); i++) { const auto& input_device = JUST(inputs.at(i)->device()); if (i > 0) { @@ -101,12 +100,11 @@ Maybe NaiveInterpret(const 
UserOpExpr& user_op_expr, const TensorTuple& in << default_device->ToString() << " (positional 0) and " << input_device->ToString() << " (positional " << i << ")!"; } - input_eager_blob_objects->at(i) = JUST(inputs.at(i)->eager_blob_object()); + input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object()); } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("init outputs"); - std::shared_ptr output_eager_blob_objects = - std::make_shared(outputs->size()); + vm::EagerBlobObjectList output_eager_blob_objects(outputs->size()); auto* output_tensor_metas = ThreadLocalDefaultOutputMutTensorMetas(outputs->size()); for (int i = 0; i < outputs->size(); i++) { if (!outputs->at(i)) { @@ -116,7 +114,7 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } else { bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); CHECK_OR_RETURN(has_eager_blob_object); - output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object()); + output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object()); } } Symbol stream; @@ -153,9 +151,9 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("init output eager_blob_objects"); - for (int i = 0; i < output_eager_blob_objects->size(); i++) { + for (int i = 0; i < output_eager_blob_objects.size(); i++) { auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); - if (!output_eager_blob_objects->at(i)) { + if (!output_eager_blob_objects.at(i)) { // NOTE: if op support stride(non-contiguous input), then output tensor's stride // should be inferred in InferLogicalTensorDesc. // otherwise, it will be set here(according to shape). 
@@ -165,7 +163,7 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } const auto& dep_object = NewLocalDepObject(); JUST(tensor_impl->InitEagerBlobObject(dep_object)); - output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object()); + output_eager_blob_objects.at(i) = JUST(tensor_impl->eager_blob_object()); } else { // output i is inplaced. // check thread_local TensorMeta and tensor_impl TensorMeta. @@ -183,12 +181,14 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in kernel->set_need_check_mem_case(need_check_mem_case); for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { - output_eager_blob_objects->at(index)->set_is_shape_synced(false); + output_eager_blob_objects.at(index)->set_is_shape_synced(false); } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("PhysicalRun"); + auto* inputs_ptr = &input_eager_blob_objects; + auto* outputs_ptr = &output_eager_blob_objects; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream); + return builder->Call(kernel, std::move(*inputs_ptr), std::move(*outputs_ptr), ctx, stream); })); OF_PROFILER_RANGE_POP(); return Maybe::Ok(); diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 4d4cb15f916..e83772c8f49 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -54,13 +54,13 @@ class ZeroCopyBaseContextHelper { user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, const int32_t index) const { - RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get()); + RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get()); return nullptr; } user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, const int32_t index) const { - 
RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get()); + RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get()); if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); } return nullptr; } @@ -788,10 +788,10 @@ Maybe StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx, DataType primary_dtype = kInvalidDataType; const auto& inputs = call_ctx->inputs(); const auto& outputs = call_ctx->outputs(); - if (likely(!inputs->empty())) { - primary_dtype = (*inputs)[0]->data_type(); - } else if (likely(!outputs->empty())) { - primary_dtype = (*outputs)[0]->data_type(); + if (likely(!inputs.empty())) { + primary_dtype = inputs[0]->data_type(); + } else if (likely(!outputs.empty())) { + primary_dtype = outputs[0]->data_type(); } else { // do nothing } From 2dc89b6d91aba8b1f2d71516c0361e0a3a2a7a55 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Sun, 10 Jul 2022 10:14:32 +0800 Subject: [PATCH 07/67] PhysicalRun --- oneflow/core/framework/instructions_builder.cpp | 9 --------- oneflow/core/framework/instructions_builder.h | 10 +++++++++- .../op_interpreter/eager_consistent_op_interpreter.cpp | 6 ++---- .../op_interpreter/eager_local_op_interpreter.cpp | 5 ++--- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index ff73fdbd02b..24b6396595c 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -37,7 +37,6 @@ limitations under the License. 
#include "oneflow/core/eager/op_call_instruction_type.h" #include "oneflow/core/vm/barrier_instruction_type.h" #include "oneflow/core/vm/virtual_machine.h" -#include "oneflow/core/vm/vm_util.h" #include "oneflow/core/framework/consistent_tensor_infer_cache.h" #include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/eager/critical_section_instruction_type.h" @@ -611,12 +610,4 @@ Maybe InstructionsBuilder::Barrier(const std::function& Callback) return Maybe::Ok(); } -Maybe PhysicalRun(const std::function(InstructionsBuilder*)>& Build) { - vm::InstructionList instruction_list; - InstructionsBuilder instructions_builder(&instruction_list); - JUST(Build(&instructions_builder)); - JUST(vm::Run(instructions_builder.mut_instruction_list())); - return Maybe::Ok(); -} - } // namespace oneflow diff --git a/oneflow/core/framework/instructions_builder.h b/oneflow/core/framework/instructions_builder.h index 85cb18cca12..97a6db39446 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/operator/op_conf_symbol.h" +#include "oneflow/core/vm/vm_util.h" namespace oneflow { @@ -149,7 +150,14 @@ class InstructionsBuilder : public std::enable_shared_from_this PhysicalRun(const std::function(InstructionsBuilder*)>& Build); +template +Maybe PhysicalRun(const CallbackT& Build) { + vm::InstructionList instruction_list; + InstructionsBuilder instructions_builder(&instruction_list); + JUST(Build(&instructions_builder)); + JUST(vm::Run(instructions_builder.mut_instruction_list())); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index c8a2bcd6e42..066e7ff9f2c 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -176,11 +176,9 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, const auto& local_tensor = JUST(outputs->at(i)->cur_rank_phy_tensor()); output_eager_blob_objects.at(i) = JUST(local_tensor->eager_blob_object()); } - auto* inputs_ptr = &input_eager_blob_objects; - auto* outputs_ptr = &output_eager_blob_objects; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->Call(kernel, std::move(*inputs_ptr), std::move(*outputs_ptr), result, ctx, - result->stream()); + return builder->Call(kernel, std::move(input_eager_blob_objects), + std::move(output_eager_blob_objects), result, ctx, result->stream()); })); return Maybe::Ok(); } diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 94aa89bf1a7..dd64e7c4d00 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ 
b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -185,10 +185,9 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("PhysicalRun"); - auto* inputs_ptr = &input_eager_blob_objects; - auto* outputs_ptr = &output_eager_blob_objects; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->Call(kernel, std::move(*inputs_ptr), std::move(*outputs_ptr), ctx, stream); + return builder->Call(kernel, std::move(input_eager_blob_objects), + std::move(output_eager_blob_objects), ctx, stream); })); OF_PROFILER_RANGE_POP(); return Maybe::Ok(); From 8ac3e434905b943ce88acafec55a27697b012664 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Sun, 10 Jul 2022 12:43:20 +0800 Subject: [PATCH 08/67] refactor InstructionsBuilder::Call --- oneflow/core/eager/local_dep_object.h | 4 + .../core/eager/op_call_phy_instr_operand.cpp | 18 ++--- .../core/eager/op_call_phy_instr_operand.h | 9 ++- .../core/framework/instructions_builder.cpp | 73 ++++++++++++++----- oneflow/core/framework/instructions_builder.h | 7 +- ...ume_local_dep_object_phy_instr_operand.cpp | 25 +++++-- ...nsume_local_dep_object_phy_instr_operand.h | 27 +++---- oneflow/core/vm/phy_instr_operand.h | 3 +- 8 files changed, 110 insertions(+), 56 deletions(-) diff --git a/oneflow/core/eager/local_dep_object.h b/oneflow/core/eager/local_dep_object.h index 038743b1d6d..edfe7d73c62 100644 --- a/oneflow/core/eager/local_dep_object.h +++ b/oneflow/core/eager/local_dep_object.h @@ -20,6 +20,8 @@ limitations under the License. 
#include "oneflow/core/vm/vm_object.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/small_vector.h" +#include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/framework/device.h" namespace oneflow { @@ -27,6 +29,8 @@ namespace oneflow { // LocalDepObject helps VirtualMachineEngine building instruction edges using LocalDepObject = vm::Dependence; +using DependenceVector = small_vector; + intrusive::shared_ptr NewLocalDepObject(); } // namespace oneflow diff --git a/oneflow/core/eager/op_call_phy_instr_operand.cpp b/oneflow/core/eager/op_call_phy_instr_operand.cpp index ba81bdca7a8..e29ffce8645 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.cpp +++ b/oneflow/core/eager/op_call_phy_instr_operand.cpp @@ -39,9 +39,9 @@ OpCallPhyInstrOperand::OpCallPhyInstrOperand( dev_vm_dep_object_consume_mode_(dev_vm_dep_object_consume_mode), input_dependences_(), output_dependences_() { - ForEachConstDependence(SetInserter(&input_dependences_)); - ForEachMutDependence(SetInserter(&output_dependences_)); - ForEachMut2Dependence(SetInserter(&output_dependences_)); + ForEachConstDependence([&](auto* dep) { input_dependences_.emplace_back(dep); }); + ForEachMutDependence([&](auto* dep) { output_dependences_.emplace_back(dep); }); + ForEachMut2Dependence([&](auto* dep) { output_dependences_.emplace_back(dep); }); InitStreamSequentialDependence(); } @@ -49,8 +49,8 @@ Maybe OpCallPhyInstrOperand::Init() { return mut_opkernel()->ChooseOpKernel(&call_ctx_, &user_opkernel_, &need_temp_storage_); } -void OpCallPhyInstrOperand::ForEachConstDependence( - const std::function& DoEach) const { +template +void OpCallPhyInstrOperand::ForEachConstDependence(const DoEachT& DoEach) const { const auto& input_list = inputs(); for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) { const auto& input = input_list.at(index); @@ -74,8 +74,8 @@ void OpCallPhyInstrOperand::InitStreamSequentialDependence() 
{ } } -void OpCallPhyInstrOperand::ForEachMutDependence( - const std::function& DoEach) const { +template +void OpCallPhyInstrOperand::ForEachMutDependence(const DoEachT& DoEach) const { const auto& opt_transport_dep_object = vm_stream_->transport_local_dep_object(); if (opt_transport_dep_object.has_value()) { DoEach(CHECK_JUST(opt_transport_dep_object)->get()); } @@ -91,8 +91,8 @@ void OpCallPhyInstrOperand::ForEachMutDependence( } } -void OpCallPhyInstrOperand::ForEachMut2Dependence( - const std::function& DoEach) const { +template +void OpCallPhyInstrOperand::ForEachMut2Dependence(const DoEachT& DoEach) const { const auto& output_list = outputs(); for (int64_t index : opkernel().output_tuple_indexes4mut2_obns()) { const auto& output = output_list.at(index); diff --git a/oneflow/core/eager/op_call_phy_instr_operand.h b/oneflow/core/eager/op_call_phy_instr_operand.h index b8f1eb3b075..7a9c884a9c5 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.h +++ b/oneflow/core/eager/op_call_phy_instr_operand.h @@ -68,11 +68,14 @@ class OpCallPhyInstrOperand final : public vm::PhyInstrOperand { const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } - void ForEachConstDependence(const std::function&) const; + template + void ForEachConstDependence(const DoEachT& DoEach) const; - void ForEachMutDependence(const std::function&) const; + template + void ForEachMutDependence(const DoEachT& DoEach) const; - void ForEachMut2Dependence(const std::function&) const; + template + void ForEachMut2Dependence(const DoEachT& DoEach) const; bool need_temp_storage() const { return need_temp_storage_; } const user_op::OpKernel* user_opkernel() const { return user_opkernel_; } diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 24b6396595c..c218398de45 100644 --- 
a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -439,33 +439,72 @@ Maybe InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& return Maybe::Ok(); } -Maybe InstructionsBuilder::SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects, - Symbol stream) { - SmallSet> last_used_streams; - for (const auto& eager_blob_object : eager_blob_objects) { - const auto& opt_last_used_stream = eager_blob_object->last_used_stream(); - if (unlikely(!opt_last_used_stream.has_value())) { continue; } - const auto& last_used_stream = JUST(opt_last_used_stream); - if (last_used_stream != stream) { SmallSetInsert(&last_used_streams, last_used_stream); } - } - for (const auto& last_used_stream : last_used_streams) { - std::vector> dep_objects; - dep_objects.reserve(eager_blob_objects.size()); +namespace { + +template +Maybe ForEachEagerBlobObjectsNeedingSoftSync( + const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream, + const DoEachT& DoEach) { + if (eager_blob_objects.size() <= kOpArgsReservedSize) { for (const auto& eager_blob_object : eager_blob_objects) { const auto& opt_last_used_stream = eager_blob_object->last_used_stream(); if (unlikely(!opt_last_used_stream.has_value())) { continue; } - if (JUST(opt_last_used_stream) == last_used_stream) { - dep_objects.emplace_back(JUST(eager_blob_object->compute_local_dep_object())); + const auto& last_used_stream = JUST(opt_last_used_stream); + if (last_used_stream != stream) { + const auto& ForEachEagerBlobObject = [&](const auto& DoEachEagerBlobObject) -> Maybe { + return DoEachEagerBlobObject(eager_blob_object); + }; + JUST(DoEach(last_used_stream, ForEachEagerBlobObject)); } - eager_blob_object->set_last_used_stream(stream); } - JUST(SoftSyncStream(std::move(dep_objects), "mut", last_used_stream)); + } else { + SmallSet> last_used_streams; + for (const auto& eager_blob_object : eager_blob_objects) { + const auto& opt_last_used_stream = 
eager_blob_object->last_used_stream(); + if (unlikely(!opt_last_used_stream.has_value())) { continue; } + const auto& last_used_stream = JUST(opt_last_used_stream); + if (last_used_stream != stream) { SmallSetInsert(&last_used_streams, last_used_stream); } + } + for (const auto& last_used_stream : last_used_streams) { + const auto& ForEachEagerBlobObject = [&](const auto& DoEachEagerBlobObject) -> Maybe { + for (const auto& eager_blob_object : eager_blob_objects) { + const auto& opt_stream = eager_blob_object->last_used_stream(); + if (unlikely(!opt_stream.has_value())) { continue; } + if (JUST(opt_stream) == last_used_stream) { + JUST(DoEachEagerBlobObject(eager_blob_object)); + } + } + return Maybe::Ok(); + }; + JUST(DoEach(last_used_stream, ForEachEagerBlobObject)); + } + } + return Maybe::Ok(); +} + +} // namespace + +Maybe InstructionsBuilder::SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects, + Symbol stream) { + JUST(ForEachEagerBlobObjectsNeedingSoftSync( + eager_blob_objects, stream, + [&](Symbol last_used_stream, const auto& ForEachEagerBlobObject) -> Maybe { + small_vector, kOpArgsReservedSize> dep_objects{}; + JUST(ForEachEagerBlobObject([&](const auto& eager_blob_object) -> Maybe { + dep_objects.emplace_back(JUST(eager_blob_object->compute_local_dep_object())); + return Maybe::Ok(); + })); + return SoftSyncStream(std::move(dep_objects), "mut", last_used_stream); + })); + for (const auto& eager_blob_object : eager_blob_objects) { + eager_blob_object->set_last_used_stream(stream); } return Maybe::Ok(); } Maybe InstructionsBuilder::SoftSyncStream( - std::vector>&& compute_local_dep_objects, + small_vector, kOpArgsReservedSize>&& + compute_local_dep_objects, const std::string& modifier, Symbol last_used_stream) { DeviceType device_type = last_used_stream->device()->enum_type(); if (!NeedSoftSync::Visit(last_used_stream->stream_role(), device_type)) { diff --git a/oneflow/core/framework/instructions_builder.h 
b/oneflow/core/framework/instructions_builder.h index 97a6db39446..6f85f4ed6cd 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/eager/op_call_phy_instr_operand.h" #include "oneflow/core/eager/lazy_job_phy_instr_operand.h" +#include "oneflow/core/eager/local_dep_object.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/parallel_desc.h" @@ -133,9 +134,9 @@ class InstructionsBuilder : public std::enable_shared_from_this SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream); - Maybe SoftSyncStream( - std::vector>&& compute_local_dep_objects, - const std::string& modifier, Symbol stream); + Maybe SoftSyncStream(small_vector, + kOpArgsReservedSize>&& compute_local_dep_objects, + const std::string& modifier, Symbol stream); private: template diff --git a/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.cpp b/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.cpp index 103cbfea259..fc484588a0b 100644 --- a/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.cpp +++ b/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.cpp @@ -20,22 +20,35 @@ namespace oneflow { namespace vm { -void ConsumeLocalDepObjectPhyInstrOperand::ForEachConstDependence( - const std::function& DoEach) const { +ConsumeLocalDepObjectPhyInstrOperand::ConsumeLocalDepObjectPhyInstrOperand( + small_vector, kOpArgsReservedSize>&& + compute_local_dep_objects, + const std::string& modifier) + : compute_local_dep_objects_(std::move(compute_local_dep_objects)), + modifier_(modifier), + input_dependences_(), + output_dependences_() { + ForEachConstDependence([&](auto* dep) { input_dependences_.emplace_back(dep); }); + ForEachMutDependence([&](auto* dep) { output_dependences_.emplace_back(dep); }); + ForEachMut2Dependence([&](auto* dep) { 
output_dependences_.emplace_back(dep); }); + stream_sequential_dependence_ = nullptr; +} +template +void ConsumeLocalDepObjectPhyInstrOperand::ForEachConstDependence(const DoEachT& DoEach) const { if (modifier_ == "const") { for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); } } } -void ConsumeLocalDepObjectPhyInstrOperand::ForEachMutDependence( - const std::function& DoEach) const { +template +void ConsumeLocalDepObjectPhyInstrOperand::ForEachMutDependence(const DoEachT& DoEach) const { if (modifier_ == "mut") { for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); } } } -void ConsumeLocalDepObjectPhyInstrOperand::ForEachMut2Dependence( - const std::function& DoEach) const { +template +void ConsumeLocalDepObjectPhyInstrOperand::ForEachMut2Dependence(const DoEachT& DoEach) const { if (modifier_ == "mut2") { for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); } } diff --git a/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.h b/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.h index d2c97baa495..e3d5fefa267 100644 --- a/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.h +++ b/oneflow/core/vm/consume_local_dep_object_phy_instr_operand.h @@ -27,33 +27,28 @@ namespace vm { class ConsumeLocalDepObjectPhyInstrOperand : public PhyInstrOperand { public: ConsumeLocalDepObjectPhyInstrOperand( - std::vector>&& compute_local_dep_objects, - const std::string& modifier) - : compute_local_dep_objects_(std::move(compute_local_dep_objects)), - modifier_(modifier), - input_dependences_(), - output_dependences_() { - ForEachConstDependence(SetInserter(&input_dependences_)); - ForEachMutDependence(SetInserter(&output_dependences_)); - ForEachMut2Dependence(SetInserter(&output_dependences_)); - stream_sequential_dependence_ = nullptr; - } - + small_vector, kOpArgsReservedSize>&& + compute_local_dep_objects, + const std::string& modifier); ~ConsumeLocalDepObjectPhyInstrOperand() = default; const 
DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } - void ForEachConstDependence(const std::function&) const; + template + void ForEachConstDependence(const DoEachT& DoEach) const; - void ForEachMutDependence(const std::function&) const; + template + void ForEachMutDependence(const DoEachT& DoEach) const; - void ForEachMut2Dependence(const std::function&) const; + template + void ForEachMut2Dependence(const DoEachT& DoEach) const; void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {} private: - std::vector> compute_local_dep_objects_; + small_vector, kOpArgsReservedSize> + compute_local_dep_objects_; const std::string modifier_; DependenceVector input_dependences_; DependenceVector output_dependences_; diff --git a/oneflow/core/vm/phy_instr_operand.h b/oneflow/core/vm/phy_instr_operand.h index 5098396ed59..df979e02b2b 100644 --- a/oneflow/core/vm/phy_instr_operand.h +++ b/oneflow/core/vm/phy_instr_operand.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include #include "oneflow/core/intrusive/intrusive.h" +#include "oneflow/core/eager/local_dep_object.h" namespace oneflow { namespace vm { @@ -28,8 +29,6 @@ namespace vm { class Dependence; class EagerBlobObject; -using DependenceVector = std::vector; - // physical instruction operand class PhyInstrOperand { public: From f87dcdf78b24f0dd8c032d69bc765956a712f013 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Sun, 10 Jul 2022 18:44:11 +0800 Subject: [PATCH 09/67] remove unused StatefulOpKernel::need_check_mem_case --- oneflow/core/eager/op_call_phy_instr_operand.cpp | 2 ++ .../framework/op_interpreter/eager_local_op_interpreter.cpp | 3 --- .../core/framework/op_interpreter/op_interpreter_util.cpp | 2 -- oneflow/user/kernels/stateful_opkernel.cpp | 1 - oneflow/user/kernels/stateful_opkernel.h | 5 ----- 5 files changed, 2 insertions(+), 11 deletions(-) diff --git a/oneflow/core/eager/op_call_phy_instr_operand.cpp b/oneflow/core/eager/op_call_phy_instr_operand.cpp index e29ffce8645..d0db5bbdd25 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.cpp +++ b/oneflow/core/eager/op_call_phy_instr_operand.cpp @@ -18,6 +18,7 @@ limitations under the License. 
#include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/vm/stream.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace vm { @@ -46,6 +47,7 @@ OpCallPhyInstrOperand::OpCallPhyInstrOperand( } Maybe OpCallPhyInstrOperand::Init() { + OF_PROFILER_RANGE_GUARD("OpCallPhyInstrOperand::Init"); return mut_opkernel()->ChooseOpKernel(&call_ctx_, &user_opkernel_, &need_temp_storage_); } diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index dd64e7c4d00..486c69131ef 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -118,7 +118,6 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } } Symbol stream; - bool need_check_mem_case = true; OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("infer devices"); @@ -130,7 +129,6 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in *JUST(tensor_impl->mut_device()) = default_device; } } else { - need_check_mem_case = false; stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs)); } @@ -178,7 +176,6 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("init opkernel"); const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream)); - kernel->set_need_check_mem_case(need_check_mem_case); for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { output_eager_blob_objects.at(index)->set_is_shape_synced(false); diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp index 207e56593b5..d074db74a36 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp +++ 
b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp @@ -126,7 +126,6 @@ Maybe GetInterpreter(const TensorTuple& inputs, const OpExp template<> /* static */ Maybe OpInterpUtil::Dispatch( const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) { - OF_PROFILER_RANGE_GUARD("Dispatch"); auto outputs = std::make_shared(op_expr.output_size()); JUST(Dispatch(op_expr, inputs, outputs.get(), ctx)); return outputs; @@ -136,7 +135,6 @@ template<> /* static */ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) { - OF_PROFILER_RANGE_GUARD("Dispatch"); return JUST(Dispatch(op_expr, inputs, ctx))->at(0); } diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index e83772c8f49..6835ed54de5 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -740,7 +740,6 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr opkernel->stream_ = stream; opkernel->input_arg_tuple_ = input_arg_tuple; opkernel->output_arg_tuple_ = output_arg_tuple; - opkernel->need_check_mem_case_ = true; const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag())); const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get(); diff --git a/oneflow/user/kernels/stateful_opkernel.h b/oneflow/user/kernels/stateful_opkernel.h index 91eb58a326f..cfcb0477de2 100644 --- a/oneflow/user/kernels/stateful_opkernel.h +++ b/oneflow/user/kernels/stateful_opkernel.h @@ -75,8 +75,6 @@ class StatefulOpKernel final { size_t InferTmpSize(eager::CallContext* call_ctx, const user_op::OpKernel* user_opkernel) const; - void set_need_check_mem_case(bool value) { need_check_mem_case_ = value; } - Maybe ChooseOpKernel(eager::CallContext* call_ctx, const user_op::OpKernel** user_opkernel, bool* need_temp_storage); @@ -101,8 +99,6 @@ class StatefulOpKernel final { return 
op_kernel_state_map_.at(opkernel).get(); } - bool need_check_mem_case() const { return need_check_mem_case_; } - const user_op::InferTmpSizeFn& GetInferTmpSizeFn(const user_op::OpKernel* op_kernel) const; std::shared_ptr op_conf_; @@ -115,7 +111,6 @@ class StatefulOpKernel final { std::unique_ptr compute_ctx_helper_; std::shared_ptr input_arg_tuple_; std::shared_ptr output_arg_tuple_; - bool need_check_mem_case_; user_op::TensorDescInferFn tensor_desc_infer_fn_; user_op::DataTypeInferFn data_type_infer_fn_; // NOTE: every device has its own stateful local opkernel instance, From fd288f75c4e4f1441aa2fd7538001104ecc3d3c1 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Sun, 10 Jul 2022 19:04:51 +0800 Subject: [PATCH 10/67] remove EagerLocalTensorImpl::is_shape_synced_ --- oneflow/core/eager/eager_blob_object.cpp | 1 - oneflow/core/eager/eager_blob_object.h | 5 ----- .../op_interpreter/eager_local_op_interpreter.cpp | 13 ++++++++++--- oneflow/core/framework/tensor_impl.cpp | 10 ---------- oneflow/core/framework/tensor_methods.cpp | 1 - 5 files changed, 10 insertions(+), 20 deletions(-) diff --git a/oneflow/core/eager/eager_blob_object.cpp b/oneflow/core/eager/eager_blob_object.cpp index f2fc0dbd204..2d7a1b8bb20 100644 --- a/oneflow/core/eager/eager_blob_object.cpp +++ b/oneflow/core/eager/eager_blob_object.cpp @@ -36,7 +36,6 @@ EagerBlobObject::EagerBlobObject(const std::shared_ptr& mem_case, tensor_storage_(tensor_storage), mem_ptr_for_allocation_compuation_pipelining_(nullptr), inited_mem_ptr_for_allocation_compuation_pipelining_(false), - is_shape_synced_(true), compute_local_dep_object_(dep_object), blob_desc_(shape, stride, data_type) { CHECK(static_cast(shape)); diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index 45e6569d3d4..b643c2deff9 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -150,10 +150,6 @@ class EagerBlobObject final : public user_op::Tensor, 
std::shared_ptr& tensor_storage() { return tensor_storage_; } - bool is_shape_synced() const { return is_shape_synced_; } - - void set_is_shape_synced(bool val) { is_shape_synced_ = val; } - const Optional>& producer_stream() const { return tensor_storage_->producer_stream(); } @@ -213,7 +209,6 @@ class EagerBlobObject final : public user_op::Tensor, // are kept even after tensor_storage_.reset(). char* mem_ptr_for_allocation_compuation_pipelining_; bool inited_mem_ptr_for_allocation_compuation_pipelining_; - std::atomic is_shape_synced_; bool pin_memory_; intrusive::shared_ptr compute_local_dep_object_; diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 486c69131ef..26ce9883858 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -31,6 +31,7 @@ limitations under the License. 
#include "oneflow/core/operator/operator.h" #include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/vm/vm_util.h" +#include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/framework/tensor_rpc_util.h" @@ -177,15 +178,21 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in OF_PROFILER_RANGE_PUSH("init opkernel"); const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream)); - for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { - output_eager_blob_objects.at(index)->set_is_shape_synced(false); - } OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("PhysicalRun"); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->Call(kernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects), ctx, stream); })); + for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { + const auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(index))); + auto btb = std::make_shared(1); + JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { + return builder->SyncAccessBlobByCallback( + tensor_impl, btb, [](uint64_t) {}, "const"); + })); + JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); + } OF_PROFILER_RANGE_POP(); return Maybe::Ok(); } diff --git a/oneflow/core/framework/tensor_impl.cpp b/oneflow/core/framework/tensor_impl.cpp index aeb96f554e5..1cc52da6235 100644 --- a/oneflow/core/framework/tensor_impl.cpp +++ b/oneflow/core/framework/tensor_impl.cpp @@ -139,16 +139,6 @@ Maybe EagerLocalTensorImpl::set_eager_blob_object( std::shared_ptr EagerLocalTensorImpl::shape() const { if (!eager_blob_object_) { return tensor_meta()->shape_ptr(); } - if (!eager_blob_object_->is_shape_synced()) { - auto btb = std::make_shared(1); - CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return 
builder->SyncAccessBlobByCallback( - this, btb, [](uint64_t) {}, "const"); - })); - TRY(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())) - .GetOrThrow(); - eager_blob_object_->set_is_shape_synced(true); - } return eager_blob_object_->shape_ptr(); } diff --git a/oneflow/core/framework/tensor_methods.cpp b/oneflow/core/framework/tensor_methods.cpp index 1ee4aa6829d..8d3ebc842ad 100644 --- a/oneflow/core/framework/tensor_methods.cpp +++ b/oneflow/core/framework/tensor_methods.cpp @@ -82,7 +82,6 @@ Maybe BasicView(const std::shared_ptr& input, const Shape& targe const std::shared_ptr& view_eager_blob_object = JUST(view_tensor->eager_blob_object()); view_eager_blob_object->set_storage_offset(JUST(view_tensor->storage_offset())); - view_eager_blob_object->set_is_shape_synced(true); return std::static_pointer_cast(view_tensor); } From ca363210e556c40a7b0ec9e6d1defca30f3b4362 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 11 Jul 2022 16:03:15 +0800 Subject: [PATCH 11/67] eager_local_interpreter_with_infer_cache --- .../framework/local_tensor_infer_cache.cpp | 209 ++++++++++++++++++ .../core/framework/local_tensor_infer_cache.h | 126 +++++++++++ oneflow/core/framework/op_expr.cpp | 2 + oneflow/core/framework/op_expr.h | 5 + .../eager_local_op_interpreter.cpp | 120 ++++------ oneflow/core/framework/tensor_meta.h | 7 + 6 files changed, 394 insertions(+), 75 deletions(-) create mode 100644 oneflow/core/framework/local_tensor_infer_cache.cpp create mode 100644 oneflow/core/framework/local_tensor_infer_cache.h diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp new file mode 100644 index 00000000000..1f9f8711828 --- /dev/null +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -0,0 +1,209 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/local_tensor_infer_cache.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/framework/infer_util.h" + +namespace oneflow { +namespace one { + +namespace { + +Maybe CheckIsDeviceSupportedByOp(const Device& device, const std::string& op_type_name) { + if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), "cpu"); } + return Maybe::Ok(); +} + +Maybe CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args, + Symbol default_device) { + if (infer_args.input_local_tensor_metas().empty()) { return Maybe::Ok(); } + for (int i = 0; i < infer_args.input_local_tensor_metas().size(); ++i) { + CHECK_OR_RETURN(default_device + == JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device()) + << Error::RuntimeError() + << "Expected all tensors to be on the same device, but found " + "at least two devices, " + << default_device->ToString() << " (positional 0) and " + << JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device()->ToString() + << " (positional " << i << ")!"; + } + return Maybe::Ok(); +} + +class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { + public: + UserOpExprDeviceAndStreamInferContext(const UserOpExpr* 
user_op_expr, + const LocalTensorMetaInferArgs* infer_args, + std::vector* output_tensor_metas) + : user_op_expr_(user_op_expr), + composed_attrs_(infer_args->attrs(), user_op_expr->base_attrs()), + in_tensor_devices_(user_op_expr_->input_size()), + out_tensor_devices_(user_op_expr_->output_size()) { + for (int i = 0; i < user_op_expr_->input_size(); ++i) { + const auto& device = infer_args->input_local_tensor_metas().at(i)->device(); + in_tensor_devices_.at(i) = device; + } + for (int i = 0; i < user_op_expr_->output_size(); ++i) { + out_tensor_devices_.at(i) = output_tensor_metas->at(i).mut_device(); + ; + } + } + + const std::vector>& inputs() const override { + return user_op_expr_->indexed_input_pairs(); + } + + const std::vector>& outputs() const override { + return user_op_expr_->indexed_output_pairs(); + } + + Symbol* OutputTensorDevice4ArgNameAndIndex(const std::string& name, + int64_t index) override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + CHECK_LT(tuple_index, user_op_expr_->output_size()); + return out_tensor_devices_.at(tuple_index); + } + + Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, + int64_t index) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + CHECK_LT(tuple_index, user_op_expr_->input_size()); + return in_tensor_devices_.at(tuple_index); + } + + private: + const std::shared_ptr& Attr4Name( + const std::string& attr_name) const override { + return composed_attrs_.Attr4Name(attr_name); + } + const UserOpExpr* user_op_expr_; + const ComposedAttrMap composed_attrs_; + std::vector> in_tensor_devices_; + std::vector*> out_tensor_devices_; +}; + +Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, + const LocalTensorMetaInferArgs& infer_args, + 
std::vector* output_tensor_metas) { + if (!user_op_expr.device_and_stream_infer_fn()) { + Symbol device = infer_args.input_local_tensor_metas().at(0)->device(); + return GetDefaultStreamByDevice(device); + } else { + UserOpExprDeviceAndStreamInferContext device_and_stream_ctx(&user_op_expr, &infer_args, + output_tensor_metas); + return TRY(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx)); + } +} + +} // namespace + +size_t LocalTensorMetaInferArgs::hash_value() const { + size_t hash_value = std::hash()(attrs_); + HashCombine(&hash_value, std::hash>()(default_device_)); + const auto& tensor_meta_hash_functor = std::hash(); + for (const auto& tensor_meta : input_local_tensor_metas_) { + HashCombine(&hash_value, tensor_meta_hash_functor(*tensor_meta)); + } + return hash_value; +} + +bool LocalTensorMetaInferArgs::operator==(const LocalTensorMetaInferArgs& other) const { + return this->attrs_ == other.attrs_ && this->default_device_ == other.default_device_ + && this->input_local_tensor_metas_ == other.input_local_tensor_metas_; +} + +Maybe LocalTensorMetaInferArgs::New(const AttrMap& attrs, + Symbol default_device, + const TensorTuple& input_tensors) { + std::shared_ptr infer_args(new LocalTensorMetaInferArgs()); + infer_args->attrs_ = attrs; + infer_args->default_device_ = default_device; + infer_args->input_local_tensor_metas_.resize(input_tensors.size()); + JUST(infer_args->InitInputLocalTensorMetas(input_tensors)); + return infer_args; +} + +Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) { + for (int i = 0; i < input_tensors.size(); ++i) { + LocalTensorMeta* local_tensor_meta = + dynamic_cast(input_tensors.at(i)->mut_tensor_meta()); + CHECK_NOTNULL_OR_RETURN(local_tensor_meta); + input_local_tensor_metas_.at(i) = SymbolOf(*local_tensor_meta); + } + return Maybe::Ok(); +} + +/* static */ Maybe LocalTensorInferCache::Infer( + const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args) { + 
const auto& default_device = infer_args.default_device(); + JUST(CheckInputDeviceIdentical(infer_args, default_device)); + JUST(CheckIsDeviceSupportedByOp(*default_device, user_op_expr.op_type_name())); + + auto result = std::make_unique(user_op_expr.output_size()); + + std::vector output_mut_metas(user_op_expr.output_size()); + // Infer devices + Symbol stream; + if (!user_op_expr.has_device_and_stream_infer_fn()) { + stream = JUST(GetDefaultStreamByDevice(default_device)); + for (int i = 0; i < user_op_expr.output_size(); i++) { + auto& tensor_meta = output_mut_metas.at(i); + *tensor_meta.mut_device() = default_device; + } + } else { + stream = JUST(InferDeviceAndStream(user_op_expr, infer_args, &output_mut_metas)); + result->set_need_check_mem_case(false); + } + result->set_stream(stream); + + { + const auto& GetInputTensorMeta = [&](int32_t i) -> const TensorMeta* { + return infer_args.input_local_tensor_metas().at(i).shared_from_symbol().get(); + }; + JUST(user_op_expr.InferPhysicalTensorDesc( + infer_args.attrs(), stream->device()->type(), GetInputTensorMeta, + [&](int32_t i) -> TensorMeta* { return &output_mut_metas.at(i); })); + } + + auto* output_metas = result->mut_output_tensor_metas(); + for (int32_t i = 0; i < user_op_expr.output_size(); ++i) { + output_metas->at(i) = SymbolOf(output_mut_metas.at(i)); + } + return std::shared_ptr(std::move(result)); +} + +Maybe LocalTensorInferCache::GetOrInfer( + const LocalTensorMetaInferArgs& infer_args) { + auto iter = cache_.find(infer_args); + if (iter == cache_.end()) { + const auto& user_op_expr = user_op_expr_.lock(); + CHECK_OR_RETURN(static_cast(user_op_expr)); + const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args)); + iter = cache_.emplace(infer_args, output_tensor_metas).first; + } + return iter->second; +} + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h new file mode 100644 
index 00000000000..1745840c1cf --- /dev/null +++ b/oneflow/core/framework/local_tensor_infer_cache.h @@ -0,0 +1,126 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_ +#define ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_ + +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/stream.h" +#include "oneflow/core/framework/tensor_meta.h" + +namespace oneflow { + +class Device; + +namespace one { + +class TensorTuple; +class UserOpExpr; + +class LocalTensorMetaInferArgs final { + public: + LocalTensorMetaInferArgs(const LocalTensorMetaInferArgs&) = default; + LocalTensorMetaInferArgs(LocalTensorMetaInferArgs&&) = default; + ~LocalTensorMetaInferArgs() = default; + + const std::vector>& input_local_tensor_metas() const { + return input_local_tensor_metas_; + } + const AttrMap& attrs() const { return attrs_; } + + const Symbol& default_device() const { return default_device_; } + + size_t hash_value() const; + + bool operator==(const LocalTensorMetaInferArgs& other) const; + + static Maybe New(const AttrMap& attrs, Symbol default_device, + const TensorTuple& input_tensors); + + private: + LocalTensorMetaInferArgs() = default; + Maybe InitInputLocalTensorMetas(const TensorTuple& input_tensors); + + AttrMap attrs_; + 
Symbol default_device_; + std::vector> input_local_tensor_metas_; +}; + +} // namespace one +} // namespace oneflow + +namespace std { + +template<> +struct hash final { + size_t operator()(const oneflow::one::LocalTensorMetaInferArgs& val) const { + return val.hash_value(); + } +}; + +} // namespace std + +namespace oneflow { +namespace one { + +class LocalTensorInferResult final { + public: + LocalTensorInferResult(size_t output_size) + : output_tensor_metas_(output_size), need_check_mem_case_(true) {} + LocalTensorInferResult(const LocalTensorInferResult&) = delete; + LocalTensorInferResult(LocalTensorInferResult&&) = delete; + ~LocalTensorInferResult() = default; + + const std::vector>& output_tensor_metas() const { + return output_tensor_metas_; + } + std::vector>* mut_output_tensor_metas() { return &output_tensor_metas_; } + + const Symbol& stream() const { return stream_; } + void set_stream(const Symbol& stream) { stream_ = stream; } + + bool need_check_mem_case() const { return need_check_mem_case_; } + void set_need_check_mem_case(bool need_check_mem_case) { + need_check_mem_case_ = need_check_mem_case; + } + + private: + std::vector> output_tensor_metas_; + Symbol stream_; + bool need_check_mem_case_; +}; + +class LocalTensorInferCache final { + public: + LocalTensorInferCache(const std::shared_ptr& user_op_expr) + : user_op_expr_(user_op_expr) {} + + Maybe GetOrInfer(const LocalTensorMetaInferArgs& infer_args); + + private: + static Maybe Infer(const UserOpExpr& user_op_expr, + const LocalTensorMetaInferArgs& infer_args); + + std::weak_ptr user_op_expr_; + HashMap> cache_; +}; + +} // namespace one +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_ diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 47c5a1d0d79..13113237061 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -22,6 +22,7 @@ limitations under the License. 
#include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_interpreter/dispatch_frame.h" #include "oneflow/core/framework/user_op_registry_manager.h" +#include "oneflow/core/framework/local_tensor_infer_cache.h" #include "oneflow/core/framework/global_tensor_infer_cache.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/user/kernels/stateful_opkernel.h" @@ -457,6 +458,7 @@ Maybe UserOpExpr::Init(const std::shared_ptr& self) { if (registry->device_and_stream_infer_fn) { device_and_stream_infer_fn_ = registry->device_and_stream_infer_fn; } + local_tensor_infer_cache_.reset(new LocalTensorInferCache(self)); global_tensor_infer_cache_.reset(new GlobalTensorInferCache(self)); return Maybe::Ok(); } diff --git a/oneflow/core/framework/op_expr.h b/oneflow/core/framework/op_expr.h index d2072249388..13a7a7a0a07 100644 --- a/oneflow/core/framework/op_expr.h +++ b/oneflow/core/framework/op_expr.h @@ -126,6 +126,7 @@ class BuiltinOpExprImpl : public BuiltinOpExpr { }; class StatefulOpKernel; +class LocalTensorInferCache; class GlobalTensorInferCache; class UserOpExpr final : public BuiltinOpExprImpl { @@ -159,6 +160,9 @@ class UserOpExpr final : public BuiltinOpExprImpl { const std::function& TensorMeta4OutputIndex) const; Maybe> InferDeviceAndStream(const AttrMap& attrs, const TensorTuple& inputs, TensorTuple* outputs) const; + LocalTensorInferCache* mut_local_tensor_infer_cache() const { + return local_tensor_infer_cache_.get(); + } GlobalTensorInferCache* mut_global_tensor_infer_cache() const { return global_tensor_infer_cache_.get(); } @@ -173,6 +177,7 @@ class UserOpExpr final : public BuiltinOpExprImpl { user_op::DataTypeInferFn dtype_infer_fn_; user_op::DeviceAndStreamInferFn device_and_stream_infer_fn_; mutable HashMap, std::shared_ptr> stream2kernel_; + std::shared_ptr local_tensor_infer_cache_; std::shared_ptr global_tensor_infer_cache_; }; diff --git 
a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 5e038baac2a..c4d6d2bf5c3 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -26,6 +26,7 @@ limitations under the License. #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/local_tensor_infer_cache.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/operator/operator.h" @@ -39,15 +40,22 @@ limitations under the License. #include "oneflow/core/framework/id_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/rpc/include/global_process_ctx.h" +#include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace one { namespace { -Maybe> GetDefaultDevice(const OpExprInterpContext& ctx) { - if (ctx.device.has_value()) { return JUST(ctx.device); } - return Device::New("cpu", 0); +Maybe> GetDefaultDevice(const TensorTuple& inputs, const OpExprInterpContext& ctx) { + if (inputs.empty()) { + if (ctx.device.has_value()) { + return JUST(ctx.device); + } else { + return Device::New("cpu", 0); + } + } + return JUST(inputs.at(0)->device()); } Maybe TensorImpl4Tensor(const std::shared_ptr& tensor) { @@ -84,70 +92,36 @@ std::vector* ThreadLocalDefaultOutputMutTensorMetas(int64_t size) { } // namespace Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, - const Symbol& default_device, TensorTuple* outputs, - const OpExprInterpContext& ctx) { - const auto& attrs = ctx.attrs; + TensorTuple* outputs, const OpExprInterpContext& ctx) { + CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); + Symbol default_device = JUST(GetDefaultDevice(inputs, ctx)); + + std::shared_ptr 
infer_args = + JUST(LocalTensorMetaInferArgs::New(ctx.attrs, default_device, inputs)); + std::shared_ptr result = + JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(*infer_args)); + std::shared_ptr input_eager_blob_objects = std::make_shared(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - const auto& input_device = JUST(inputs.at(i)->device()); - if (i > 0) { - CHECK_OR_RETURN(*default_device == *input_device) - << Error::RuntimeError() - << "Expected all tensors to be on the same device, but found at least two devices, " - << default_device->ToString() << " (positional 0) and " << input_device->ToString() - << " (positional " << i << ")!"; + if (inputs.size() > 0) { + for (int i = 0; i < inputs.size(); i++) { + input_eager_blob_objects->at(i) = JUST(inputs.at(i)->eager_blob_object()); } - input_eager_blob_objects->at(i) = JUST(inputs.at(i)->eager_blob_object()); } + + const auto& output_tensor_metas = result->output_tensor_metas(); std::shared_ptr output_eager_blob_objects = std::make_shared(outputs->size()); - auto* output_tensor_metas = ThreadLocalDefaultOutputMutTensorMetas(outputs->size()); + for (int i = 0; i < outputs->size(); i++) { if (!outputs->at(i)) { - const auto& tensor_impl = std::make_shared(); - (*outputs)[i] = std::make_shared(tensor_impl); - output_tensor_metas->at(i) = tensor_impl->mut_tensor_meta(); - } else { - bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); - CHECK_OR_RETURN(has_eager_blob_object); - output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object()); - } - } - Symbol stream; - bool need_check_mem_case = true; - - // Infer devices - if (!user_op_expr.has_device_and_stream_infer_fn()) { - stream = JUST(GetDefaultStreamByDevice(default_device)); - for (int i = 0; i < outputs->size(); i++) { - auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); - *JUST(tensor_impl->mut_device()) = default_device; - } - } else { - need_check_mem_case = false; - stream = 
JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs)); - } - - // Infer shapes and dtypes - const auto& device_tag = stream->device()->type(); - JUST(user_op_expr.InferPhysicalTensorDesc( - attrs, device_tag, - [&](int32_t i) -> const TensorMeta* { - return CHECK_JUST(TensorImpl4Tensor(inputs[i]))->mut_tensor_meta(); - }, - [&](int32_t i) -> TensorMeta* { - // using thread_local TensorMeta pointer if inplace. - // using tensor_impl TensorMeta pointer if not inplace. - return output_tensor_metas->at(i); - })); - - for (int i = 0; i < output_eager_blob_objects->size(); i++) { - auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); - if (!output_eager_blob_objects->at(i)) { // NOTE: if op support stride(non-contiguous input), then output tensor's stride // should be inferred in InferLogicalTensorDesc. // otherwise, it will be set here(according to shape). + // Note: symbol.shared_from_symbol() cannot be used here because set_stride happens in the + // next step. + std::shared_ptr tensor_impl = std::make_shared( + std::make_shared(*output_tensor_metas.at(i)), false, false); if (!JUST(user_op_expr.SupportNonContiguous())) { std::shared_ptr stride(new Stride(*tensor_impl->shape())); tensor_impl->mut_tensor_meta()->set_stride(stride); @@ -155,45 +129,41 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in const auto& dep_object = NewLocalDepObject(); JUST(tensor_impl->InitEagerBlobObject(dep_object)); output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object()); + (*outputs)[i] = std::make_shared(tensor_impl); } else { + auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); // output i is inplaced. - // check thread_local TensorMeta and tensor_impl TensorMeta. - CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape()); - // TODO:(thread_local TensorMeta set stride then check) + // check TensorMeta of infer result and TensorMeta of output i. 
+ CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas.at(i)->shape()); + CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas.at(i)->dtype()); + bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); + CHECK_OR_RETURN(has_eager_blob_object); + output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object()); + // TODO(zhaoluyang):(thread_local TensorMeta set stride then check) // CHECK_OR_RETURN(tensor_impl->tensor_meta()->stride() == // output_tensor_metas->at(i)->stride()); - CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype()); } } - const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream)); - kernel->set_need_check_mem_case(need_check_mem_case); + const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream())); + kernel->set_need_check_mem_case(result->need_check_mem_case()); for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { output_eager_blob_objects->at(index)->set_is_shape_synced(false); } JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream); + return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, + result->stream()); })); - return Maybe::Ok(); -} -static Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const OpExprInterpContext& ctx) { - CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); - Symbol default_device; - if (inputs.empty()) { - default_device = JUST(GetDefaultDevice(ctx)); - } else { - default_device = JUST(inputs.at(0)->device()); - } - return NaiveInterpret(user_op_expr, inputs, default_device, outputs, ctx); + return Maybe::Ok(); } Maybe EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { + 
OF_PROFILER_RANGE_GUARD("NaiveInterpret"); return NaiveInterpret(op_expr, inputs, outputs, ctx); } diff --git a/oneflow/core/framework/tensor_meta.h b/oneflow/core/framework/tensor_meta.h index a8de6998828..8511661e0fe 100644 --- a/oneflow/core/framework/tensor_meta.h +++ b/oneflow/core/framework/tensor_meta.h @@ -127,6 +127,13 @@ class GlobalTensorMeta : public TensorMeta { namespace std { +template<> +struct hash final { + size_t operator()(const oneflow::one::LocalTensorMeta& local_tensor_meta) const { + return local_tensor_meta.CalcHashValue(); + } +}; + template<> struct hash final { size_t operator()(const oneflow::one::GlobalTensorMeta& global_tensor_meta) const { From fb69d8b1cc0af2c45c4459fb743c8e7e458dbe99 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 11 Jul 2022 16:24:54 +0800 Subject: [PATCH 12/67] remove useless code --- .../eager_local_op_interpreter.cpp | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index c4d6d2bf5c3..e1d71d28b64 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -63,32 +63,6 @@ Maybe TensorImpl4Tensor(const std::shared_ptr& te return tensor->mut_eager_local_tensor_impl(); } -class MutLocalTensorMeta : public TensorMeta { // NOLINT - public: - MutLocalTensorMeta() - : TensorMeta(std::make_shared(), std::make_shared(), - kInvalidDataType) {} - MutLocalTensorMeta(const MutLocalTensorMeta&) = default; - MutLocalTensorMeta(MutLocalTensorMeta&&) = default; - ~MutLocalTensorMeta() override = default; -}; - -std::vector* ThreadLocalDefaultOutputMutTensorMetas(int64_t size) { - static thread_local std::vector struct_vec; - static thread_local std::vector ptr_vec; - struct_vec.resize(size); - ptr_vec.resize(size); - if (size == 1) { - ptr_vec.at(0) = &struct_vec.at(0); 
// unfold loop - } else if (size == 2) { - ptr_vec.at(0) = &struct_vec.at(0); // unfold loop - ptr_vec.at(1) = &struct_vec.at(1); // unfold loop - } else { - for (int i = 0; i < size; ++i) { ptr_vec.at(i) = &struct_vec.at(i); } - } - return &ptr_vec; -} - } // namespace Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, From 7845311fb0b459d0f4fb5be443e4627013ef5a8f Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 11 Jul 2022 19:03:51 +0800 Subject: [PATCH 13/67] reslove comments --- .../framework/local_tensor_infer_cache.cpp | 17 +++++++---------- .../core/framework/local_tensor_infer_cache.h | 15 ++++----------- .../eager_local_op_interpreter.cpp | 19 +++++++++---------- 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 1f9f8711828..4c2d43c9a13 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -133,15 +133,13 @@ bool LocalTensorMetaInferArgs::operator==(const LocalTensorMetaInferArgs& other) && this->input_local_tensor_metas_ == other.input_local_tensor_metas_; } -Maybe LocalTensorMetaInferArgs::New(const AttrMap& attrs, - Symbol default_device, - const TensorTuple& input_tensors) { - std::shared_ptr infer_args(new LocalTensorMetaInferArgs()); - infer_args->attrs_ = attrs; - infer_args->default_device_ = default_device; - infer_args->input_local_tensor_metas_.resize(input_tensors.size()); - JUST(infer_args->InitInputLocalTensorMetas(input_tensors)); - return infer_args; +Maybe LocalTensorMetaInferArgs::Init(const AttrMap& attrs, Symbol default_device, + const TensorTuple& input_tensors) { + this->attrs_ = attrs; + this->default_device_ = default_device; + this->input_local_tensor_metas_.resize(input_tensors.size()); + JUST(this->InitInputLocalTensorMetas(input_tensors)); + return Maybe::Ok(); } Maybe 
LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) { @@ -173,7 +171,6 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl } } else { stream = JUST(InferDeviceAndStream(user_op_expr, infer_args, &output_mut_metas)); - result->set_need_check_mem_case(false); } result->set_stream(stream); diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h index 1745840c1cf..8d763b08caf 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.h +++ b/oneflow/core/framework/local_tensor_infer_cache.h @@ -34,6 +34,7 @@ class UserOpExpr; class LocalTensorMetaInferArgs final { public: + LocalTensorMetaInferArgs() = default; LocalTensorMetaInferArgs(const LocalTensorMetaInferArgs&) = default; LocalTensorMetaInferArgs(LocalTensorMetaInferArgs&&) = default; ~LocalTensorMetaInferArgs() = default; @@ -49,11 +50,10 @@ class LocalTensorMetaInferArgs final { bool operator==(const LocalTensorMetaInferArgs& other) const; - static Maybe New(const AttrMap& attrs, Symbol default_device, - const TensorTuple& input_tensors); + Maybe Init(const AttrMap& attrs, Symbol default_device, + const TensorTuple& input_tensors); private: - LocalTensorMetaInferArgs() = default; Maybe InitInputLocalTensorMetas(const TensorTuple& input_tensors); AttrMap attrs_; @@ -80,8 +80,7 @@ namespace one { class LocalTensorInferResult final { public: - LocalTensorInferResult(size_t output_size) - : output_tensor_metas_(output_size), need_check_mem_case_(true) {} + LocalTensorInferResult(size_t output_size) : output_tensor_metas_(output_size) {} LocalTensorInferResult(const LocalTensorInferResult&) = delete; LocalTensorInferResult(LocalTensorInferResult&&) = delete; ~LocalTensorInferResult() = default; @@ -94,15 +93,9 @@ class LocalTensorInferResult final { const Symbol& stream() const { return stream_; } void set_stream(const Symbol& stream) { stream_ = stream; } - bool need_check_mem_case() const { 
return need_check_mem_case_; } - void set_need_check_mem_case(bool need_check_mem_case) { - need_check_mem_case_ = need_check_mem_case; - } - private: std::vector> output_tensor_metas_; Symbol stream_; - bool need_check_mem_case_; }; class LocalTensorInferCache final { diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 59c2eb39aad..8bd8759ecdc 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -68,19 +68,19 @@ Maybe TensorImpl4Tensor(const std::shared_ptr& te Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { + OF_PROFILER_RANGE_GUARD("NaiveInterpret"); CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); Symbol default_device = JUST(GetDefaultDevice(inputs, ctx)); - std::shared_ptr infer_args = - JUST(LocalTensorMetaInferArgs::New(ctx.attrs, default_device, inputs)); - - std::shared_ptr result = - JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(*infer_args)); + std::shared_ptr result; + { + LocalTensorMetaInferArgs infer_args; + JUST(infer_args.Init(ctx.attrs, default_device, inputs)); + result = JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); + } vm::EagerBlobObjectList input_eager_blob_objects(inputs.size()); - if (inputs.size() > 0) { - for (int i = 0; i < inputs.size(); i++) { - input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object()); - } + for (int i = 0; i < inputs.size(); i++) { + input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object()); } const auto& output_tensor_metas = result->output_tensor_metas(); @@ -142,7 +142,6 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in Maybe EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const 
TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { - OF_PROFILER_RANGE_GUARD("NaiveInterpret"); return NaiveInterpret(op_expr, inputs, outputs, ctx); } From cb2b22f8534c500a98a7510151dc33d3514612ac Mon Sep 17 00:00:00 2001 From: lixinqi Date: Mon, 11 Jul 2022 21:38:18 +0800 Subject: [PATCH 14/67] refactor TensorMeta::TensorMeta(const TensorMeta) --- oneflow/core/framework/tensor_meta.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/oneflow/core/framework/tensor_meta.h b/oneflow/core/framework/tensor_meta.h index 8511661e0fe..df8bd82f815 100644 --- a/oneflow/core/framework/tensor_meta.h +++ b/oneflow/core/framework/tensor_meta.h @@ -42,7 +42,11 @@ class TensorMeta : public user_op::TensorDesc { TensorMeta(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype) : shape_(shape), stride_(stride), data_type_(dtype), is_dynamic_(false) {} - TensorMeta(const TensorMeta&) = default; + TensorMeta(const TensorMeta& other) + : shape_(std::make_shared(*other.shape_)), + stride_(std::make_shared(*other.stride_)), + data_type_(other.data_type_), + is_dynamic_(other.is_dynamic_) {} TensorMeta(TensorMeta&&) = default; virtual ~TensorMeta() = default; @@ -77,6 +81,7 @@ class LocalTensorMeta : public TensorMeta { public: // uninitialized LocalTensorMeta. 
LocalTensorMeta(); + LocalTensorMeta(const LocalTensorMeta&) = default; LocalTensorMeta(const std::shared_ptr& shape, DataType dtype, Symbol device); LocalTensorMeta(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, From c3dccba65be3c1ddbf13589a267c4ec2908c4a00 Mon Sep 17 00:00:00 2001 From: clackhan Date: Wed, 13 Jul 2022 20:41:08 +0800 Subject: [PATCH 15/67] use small vector --- oneflow/core/common/stride.cpp | 3 +- .../framework/local_tensor_infer_cache.cpp | 50 ++++++++++++------- .../core/framework/local_tensor_infer_cache.h | 17 ++++--- 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/oneflow/core/common/stride.cpp b/oneflow/core/common/stride.cpp index 38552a832f9..ab130076065 100644 --- a/oneflow/core/common/stride.cpp +++ b/oneflow/core/common/stride.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/common/stride.h" +#include "oneflow/core/common/constant.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/cplusplus_17.h" @@ -29,7 +30,7 @@ Stride::Stride(const Shape& shape) { std::multiplies<>{}); } else if (ndim > 0 && shape.elem_cnt() == 0) { // 0-size shape - std::vector tmp_shape(ndim); + small_vector tmp_shape(ndim); for (int64_t i = 0; i < ndim; ++i) { tmp_shape[i] = shape.At(i) > 0 ? shape.At(i) : 1; } std::exclusive_scan(tmp_shape.rbegin(), tmp_shape.rend(), rbegin(), (int64_t)1, std::multiplies<>{}); diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 4c2d43c9a13..013958d1f30 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -19,6 +19,7 @@ limitations under the License. 
#include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/common/container_util.h" +#include "oneflow/core/common/util.h" #include "oneflow/core/framework/infer_util.h" namespace oneflow { @@ -26,6 +27,12 @@ namespace one { namespace { +// NOTE: use env variable 'ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE' indicate whether the +// use infer cache in naive local op interpret. +bool ParseEnableEagerLocalInferCache() { + return ParseBooleanFromEnv("ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE", true); +} + Maybe CheckIsDeviceSupportedByOp(const Device& device, const std::string& op_type_name) { if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), "cpu"); } return Maybe::Ok(); @@ -51,7 +58,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr public: UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const LocalTensorMetaInferArgs* infer_args, - std::vector* output_tensor_metas) + std::vector* output_tensor_metas) : user_op_expr_(user_op_expr), composed_attrs_(infer_args->attrs(), user_op_expr->base_attrs()), in_tensor_devices_(user_op_expr_->input_size()), @@ -62,7 +69,6 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr } for (int i = 0; i < user_op_expr_->output_size(); ++i) { out_tensor_devices_.at(i) = output_tensor_metas->at(i).mut_device(); - ; } } @@ -99,13 +105,13 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr } const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; - std::vector> in_tensor_devices_; - std::vector*> out_tensor_devices_; + small_vector, kOpArgsReservedSize> in_tensor_devices_; + small_vector*, kOpArgsReservedSize> out_tensor_devices_; }; Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args, - std::vector* output_tensor_metas) { + std::vector* output_tensor_metas) { if 
(!user_op_expr.device_and_stream_infer_fn()) { Symbol device = infer_args.input_local_tensor_metas().at(0)->device(); return GetDefaultStreamByDevice(device); @@ -144,10 +150,7 @@ Maybe LocalTensorMetaInferArgs::Init(const AttrMap& attrs, Symbol Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) { for (int i = 0; i < input_tensors.size(); ++i) { - LocalTensorMeta* local_tensor_meta = - dynamic_cast(input_tensors.at(i)->mut_tensor_meta()); - CHECK_NOTNULL_OR_RETURN(local_tensor_meta); - input_local_tensor_metas_.at(i) = SymbolOf(*local_tensor_meta); + input_local_tensor_metas_.at(i) = JUST(input_tensors.at(i)->local_tensor_meta()); } return Maybe::Ok(); } @@ -160,7 +163,7 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl auto result = std::make_unique(user_op_expr.output_size()); - std::vector output_mut_metas(user_op_expr.output_size()); + std::vector output_mut_metas(user_op_expr.output_size()); // Infer devices Symbol stream; if (!user_op_expr.has_device_and_stream_infer_fn()) { @@ -185,21 +188,34 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl auto* output_metas = result->mut_output_tensor_metas(); for (int32_t i = 0; i < user_op_expr.output_size(); ++i) { - output_metas->at(i) = SymbolOf(output_mut_metas.at(i)); + if (!JUST(user_op_expr.SupportNonContiguous())) { + std::shared_ptr stride(new Stride(output_mut_metas.at(i).shape())); + output_mut_metas.at(i).set_stride(stride); + } + output_metas->at(i) = SymbolOf( + LocalTensorMeta(output_mut_metas.at(i).shape_ptr(), output_mut_metas.at(i).stride_ptr(), + output_mut_metas.at(i).data_type(), output_mut_metas.at(i).device(), + output_mut_metas.at(i).storage_offset())); } return std::shared_ptr(std::move(result)); } Maybe LocalTensorInferCache::GetOrInfer( const LocalTensorMetaInferArgs& infer_args) { - auto iter = cache_.find(infer_args); - if (iter == cache_.end()) { + static bool enable_eager_local_infer_cache = 
ParseEnableEagerLocalInferCache(); + if (enable_eager_local_infer_cache) { + auto iter = cache_.find(infer_args); + if (iter == cache_.end()) { + const auto& user_op_expr = user_op_expr_.lock(); + CHECK_OR_RETURN(static_cast(user_op_expr)); + const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args)); + iter = cache_.emplace(infer_args, output_tensor_metas).first; + } + return iter->second; + } else { const auto& user_op_expr = user_op_expr_.lock(); - CHECK_OR_RETURN(static_cast(user_op_expr)); - const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args)); - iter = cache_.emplace(infer_args, output_tensor_metas).first; + return JUST(Infer(*user_op_expr, infer_args)); } - return iter->second; } } // namespace one diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h index 8d763b08caf..055a5edb4a1 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.h +++ b/oneflow/core/framework/local_tensor_infer_cache.h @@ -18,10 +18,12 @@ limitations under the License. 
#include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/small_vector.h" +#include "oneflow/core/common/op_args_reserved_size.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" namespace oneflow { @@ -39,7 +41,8 @@ class LocalTensorMetaInferArgs final { LocalTensorMetaInferArgs(LocalTensorMetaInferArgs&&) = default; ~LocalTensorMetaInferArgs() = default; - const std::vector>& input_local_tensor_metas() const { + const small_vector, kOpArgsReservedSize>& input_local_tensor_metas() + const { return input_local_tensor_metas_; } const AttrMap& attrs() const { return attrs_; } @@ -58,7 +61,7 @@ class LocalTensorMetaInferArgs final { AttrMap attrs_; Symbol default_device_; - std::vector> input_local_tensor_metas_; + small_vector, kOpArgsReservedSize> input_local_tensor_metas_; }; } // namespace one @@ -85,16 +88,18 @@ class LocalTensorInferResult final { LocalTensorInferResult(LocalTensorInferResult&&) = delete; ~LocalTensorInferResult() = default; - const std::vector>& output_tensor_metas() const { + const small_vector, kOpArgsReservedSize>& output_tensor_metas() const { return output_tensor_metas_; } - std::vector>* mut_output_tensor_metas() { return &output_tensor_metas_; } + small_vector, kOpArgsReservedSize>* mut_output_tensor_metas() { + return &output_tensor_metas_; + } const Symbol& stream() const { return stream_; } void set_stream(const Symbol& stream) { stream_ = stream; } private: - std::vector> output_tensor_metas_; + small_vector, kOpArgsReservedSize> output_tensor_metas_; Symbol stream_; }; From c0e643fb66a23721253e95461e830cc47b020f01 Mon Sep 17 00:00:00 2001 From: clackhan Date: Wed, 13 Jul 2022 21:31:52 +0800 Subject: [PATCH 16/67] Symbolic LocalTensorMeta --- oneflow/api/python/functional/tensor_api.cpp | 6 +- 
oneflow/core/common/constant.h | 1 + .../{framework => common}/tensor_desc.cpp | 3 +- .../core/{framework => common}/tensor_desc.h | 10 +-- .../{framework => common}/tensor_meta.cpp | 32 ++++++++- .../core/{framework => common}/tensor_meta.h | 34 ++++++++-- oneflow/core/eager/eager_blob_object.cpp | 66 ++++++++++++++++--- oneflow/core/eager/eager_blob_object.h | 61 ++++++++++------- .../core/eager/op_call_phy_instr_operand.cpp | 1 - oneflow/core/framework/consistency_check.h | 2 +- oneflow/core/framework/framework.h | 2 +- .../framework/global_tensor_infer_cache.h | 2 +- oneflow/core/framework/infer_util.h | 2 +- .../eager_local_op_interpreter.cpp | 40 +++++++---- oneflow/core/framework/placement_sbp_util.cpp | 2 +- .../framework/placement_sbp_util_test.cpp | 2 +- .../sync_symbol_global_tensor_meta.cpp | 2 +- oneflow/core/framework/tensor.cpp | 6 +- oneflow/core/framework/tensor.h | 18 ++++- oneflow/core/framework/tensor_impl.cpp | 53 +++++++-------- oneflow/core/framework/tensor_impl.h | 64 +++++++++++------- oneflow/core/framework/tensor_methods.cpp | 15 +++-- oneflow/core/framework/user_op_conf.h | 2 +- .../framework/user_op_registry_manager.cpp | 2 +- .../core/functional/impl/array_functor.cpp | 18 +++-- oneflow/core/operator/user_op.cpp | 2 +- oneflow/user/kernels/stateful_opkernel.cpp | 11 ++-- oneflow/user/kernels/stateful_opkernel.h | 7 +- 28 files changed, 328 insertions(+), 138 deletions(-) rename oneflow/core/{framework => common}/tensor_desc.cpp (94%) rename oneflow/core/{framework => common}/tensor_desc.h (91%) rename oneflow/core/{framework => common}/tensor_meta.cpp (69%) rename oneflow/core/{framework => common}/tensor_meta.h (81%) diff --git a/oneflow/api/python/functional/tensor_api.cpp b/oneflow/api/python/functional/tensor_api.cpp index 7496995bcbc..c3bf8ca90dd 100644 --- a/oneflow/api/python/functional/tensor_api.cpp +++ b/oneflow/api/python/functional/tensor_api.cpp @@ -266,7 +266,7 @@ class LocalTensorSharedNumpyDataFunctor { } stride_val /= 
element_size_in_bytes; } - auto tensor_meta = std::make_shared(shape, strides, data_type, device, 0); + auto tensor_meta = SymbolOf(LocalTensorMeta(shape, strides, data_type, device, 0)); // Build TensorBuffer const auto& Free = [array](char* dptr) { @@ -286,12 +286,12 @@ class LocalTensorSharedNumpyDataFunctor { auto tensor_storage = std::make_shared(tensor_data); // Build Tensor - auto tensor_impl = std::make_shared(tensor_meta, tensor_storage, + auto tensor_impl = std::make_shared(tensor_storage, /*requires_grad=*/false, /*ls_leaf=*/true); // Init blob - JUST(tensor_impl->InitEagerBlobObject(NewLocalDepObject())); + JUST(tensor_impl->InitEagerBlobObject(tensor_meta, NewLocalDepObject())); const auto& stream = JUST(GetDefaultStreamByDevice(device)); const auto& eager_blob_object = JUST(tensor_impl->eager_blob_object()); JUST(eager_blob_object->init_producer_stream(stream)); diff --git a/oneflow/core/common/constant.h b/oneflow/core/common/constant.h index 3f8b331bdb4..7760e161128 100644 --- a/oneflow/core/common/constant.h +++ b/oneflow/core/common/constant.h @@ -24,6 +24,7 @@ static const int64_t kInvalidSessionId = -1; static const std::string kNoPassTag = ""; static const std::string kMainOp = "main_op"; static const int64_t kMaxSplitAxis = 6; +constexpr size_t kMaxNumDims = 8; static const std::string kAsymmetricCodeErrorMsg = "Maybe executing different code in different ranks, please check if the code is branched and " "operates on the global tensor."; diff --git a/oneflow/core/framework/tensor_desc.cpp b/oneflow/core/common/tensor_desc.cpp similarity index 94% rename from oneflow/core/framework/tensor_desc.cpp rename to oneflow/core/common/tensor_desc.cpp index b13dd5dac39..ed82fe40dbe 100644 --- a/oneflow/core/framework/tensor_desc.cpp +++ b/oneflow/core/common/tensor_desc.cpp @@ -13,7 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" +#include "oneflow/core/register/blob_desc.pb.h" namespace oneflow { diff --git a/oneflow/core/framework/tensor_desc.h b/oneflow/core/common/tensor_desc.h similarity index 91% rename from oneflow/core/framework/tensor_desc.h rename to oneflow/core/common/tensor_desc.h index c22e92aa12a..fa1dbf7fe22 100644 --- a/oneflow/core/framework/tensor_desc.h +++ b/oneflow/core/common/tensor_desc.h @@ -13,16 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_DESC_H_ -#define ONEFLOW_CORE_FRAMEWORK_TENSOR_DESC_H_ +#ifndef ONEFLOW_CORE_COMMON_TENSOR_DESC_H_ +#define ONEFLOW_CORE_COMMON_TENSOR_DESC_H_ #include "oneflow/core/common/util.h" -#include "oneflow/core/register/blob_desc.pb.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/stride.h" +#include "oneflow/core/common/data_type.pb.h" namespace oneflow { +class BlobDescProto; + namespace user_op { class TensorDesc { @@ -77,4 +79,4 @@ class NaiveTensorDesc final : public TensorDesc { } // namespace oneflow -#endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_DESC_H_ +#endif // ONEFLOW_CORE_COMMON_TENSOR_DESC_H_ diff --git a/oneflow/core/framework/tensor_meta.cpp b/oneflow/core/common/tensor_meta.cpp similarity index 69% rename from oneflow/core/framework/tensor_meta.cpp rename to oneflow/core/common/tensor_meta.cpp index 7eb481f6600..06beb5d4262 100644 --- a/oneflow/core/framework/tensor_meta.cpp +++ b/oneflow/core/common/tensor_meta.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/framework/device.h" @@ -50,6 +50,36 @@ size_t LocalTensorMeta::CalcHashValue() const { ^ std::hash()(*device()) ^ std::hash()(stride()) ^ storage_offset(); } +MutLocalTensorMeta::MutLocalTensorMeta() + : TensorMeta(std::make_shared(), std::make_shared(), + kInvalidDataType), + device_(Symbol()), + storage_offset_(0) {} + +MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr& shape, DataType dtype, + Symbol device) + : TensorMeta(shape, std::make_shared(*shape), dtype), + device_(device), + storage_offset_(0) {} + +MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr& shape, + const std::shared_ptr& stride, DataType dtype, + Symbol device, int64_t storage_offset) + : TensorMeta(shape, stride, dtype), device_(device), storage_offset_(storage_offset) {} + +bool MutLocalTensorMeta::operator==(const MutLocalTensorMeta& other) const { + // It's correct to ignore is_dynamic_ field. + return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() + && *this->device() == *other.device() && this->stride() == other.stride() + && this->storage_offset() == other.storage_offset(); +} + +size_t MutLocalTensorMeta::CalcHashValue() const { + // It's correct to ignore is_dynamic_ field. + return std::hash()(*shape_ptr()) ^ std::hash()(dtype()) + ^ std::hash()(*device()) ^ std::hash()(stride()) ^ storage_offset(); +} + bool GlobalTensorMeta::operator==(const GlobalTensorMeta& other) const { // It's correct to ignore is_dynamic_ field. 
return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() diff --git a/oneflow/core/framework/tensor_meta.h b/oneflow/core/common/tensor_meta.h similarity index 81% rename from oneflow/core/framework/tensor_meta.h rename to oneflow/core/common/tensor_meta.h index df8bd82f815..fec696c9bb9 100644 --- a/oneflow/core/framework/tensor_meta.h +++ b/oneflow/core/common/tensor_meta.h @@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_FRAMEWORK_TENSOR_META_H_ -#define ONEFLOW_FRAMEWORK_TENSOR_META_H_ +#ifndef ONEFLOW_COMMON_TENSOR_META_H_ +#define ONEFLOW_COMMON_TENSOR_META_H_ #include -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/common/symbol.h" namespace oneflow { @@ -102,6 +102,32 @@ class LocalTensorMeta : public TensorMeta { int64_t storage_offset_; }; +class MutLocalTensorMeta : public TensorMeta { + public: + // uninitialized MutLocalTensorMeta. 
+ MutLocalTensorMeta(); + MutLocalTensorMeta(const MutLocalTensorMeta&) = default; + MutLocalTensorMeta(const std::shared_ptr& shape, DataType dtype, + Symbol device); + MutLocalTensorMeta(const std::shared_ptr& shape, + const std::shared_ptr& stride, DataType dtype, + Symbol device, int64_t storage_offset); + virtual ~MutLocalTensorMeta() = default; + + const Symbol& device() const { return device_; } + int64_t storage_offset() const { return storage_offset_; } + + Symbol* mut_device() { return &device_; } + void set_storage_offset(int64_t offset) { storage_offset_ = offset; } + + bool operator==(const MutLocalTensorMeta& other) const; + size_t CalcHashValue() const; + + private: + Symbol device_; + int64_t storage_offset_; +}; + class GlobalTensorMeta : public TensorMeta { public: GlobalTensorMeta(const std::shared_ptr& shape, DataType dtype, Symbol nd_sbp, @@ -148,4 +174,4 @@ struct hash final { } // namespace std -#endif // ONEFLOW_FRAMEWORK_TENSOR_META_H_ +#endif // ONEFLOW_COMMON_TENSOR_META_H_ diff --git a/oneflow/core/eager/eager_blob_object.cpp b/oneflow/core/eager/eager_blob_object.cpp index 2d7a1b8bb20..4d71526a95b 100644 --- a/oneflow/core/eager/eager_blob_object.cpp +++ b/oneflow/core/eager/eager_blob_object.cpp @@ -18,31 +18,77 @@ limitations under the License. 
#include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/shut_down_util.h" #include "oneflow/core/common/shape_vec.h" +#include "oneflow/core/common/tensor_meta.h" namespace oneflow { + namespace vm { -EagerBlobObject::EagerBlobObject(const std::shared_ptr& mem_case, - const std::shared_ptr& shape, - const std::shared_ptr& stride, DataType data_type, - const std::shared_ptr& tensor_storage, - const intrusive::shared_ptr& dep_object) +EagerBlobObject::EagerBlobObject( + const std::shared_ptr& mem_case, + const Symbol& local_tensor_meta, + const std::shared_ptr& mut_local_tensor_meta, DataType data_type, + const std::shared_ptr& tensor_storage, + const intrusive::shared_ptr& dep_object) : is_dynamic_(false), mem_case_(mem_case), data_type_(data_type), - shape_(shape), - stride_(stride), storage_offset_(0), tensor_storage_(tensor_storage), mem_ptr_for_allocation_compuation_pipelining_(nullptr), inited_mem_ptr_for_allocation_compuation_pipelining_(false), compute_local_dep_object_(dep_object), - blob_desc_(shape, stride, data_type) { - CHECK(static_cast(shape)); - CHECK(static_cast(stride)); + blob_desc_(static_cast(mut_local_tensor_meta) + ? std::const_pointer_cast(mut_local_tensor_meta->shape_ptr()) + : std::const_pointer_cast(local_tensor_meta->shape_ptr()), + static_cast(mut_local_tensor_meta) + ? 
std::const_pointer_cast(mut_local_tensor_meta->stride_ptr()) + : std::const_pointer_cast(local_tensor_meta->stride_ptr()), + data_type), + local_tensor_meta_(local_tensor_meta), + mut_local_tensor_meta_(mut_local_tensor_meta) { CHECK(static_cast(tensor_storage)); } +// user_op::TensorDesc overrides +const Shape& EagerBlobObject::shape() const { + if (mut_local_tensor_meta_) { + return mut_local_tensor_meta_->shape(); + } else { + return local_tensor_meta_->shape(); + } +} +Shape* EagerBlobObject::mut_shape() { + CHECK(mut_local_tensor_meta_); + return std::const_pointer_cast(mut_local_tensor_meta_)->mut_shape(); +} +const Stride& EagerBlobObject::stride() const { + if (mut_local_tensor_meta_) { + return mut_local_tensor_meta_->stride(); + } else { + return local_tensor_meta_->stride(); + } +} +Stride* EagerBlobObject::mut_stride() { + CHECK(mut_local_tensor_meta_); + return std::const_pointer_cast(mut_local_tensor_meta_)->mut_stride(); +} + +std::shared_ptr EagerBlobObject::shape_ptr() const { + if (mut_local_tensor_meta_) { + return mut_local_tensor_meta_->shape_ptr(); + } else { + return local_tensor_meta_->shape_ptr(); + } +} +std::shared_ptr EagerBlobObject::stride_ptr() const { + if (mut_local_tensor_meta_) { + return mut_local_tensor_meta_->stride_ptr(); + } else { + return local_tensor_meta_->stride_ptr(); + } +} + Blob* EagerBlobObject::blob() { if (!blob_) { blob_.reset(new Blob(*mem_case_, &blob_desc_, mut_header_ptr(), mut_dptr())); diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index b643c2deff9..de838200bbe 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -26,11 +26,18 @@ limitations under the License. 
#include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/tensor_methods.h" #include "oneflow/core/framework/user_op_tensor.h" -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/register/blob.h" namespace oneflow { +namespace one { + +class LocalTensorMeta; +class MutLocalTensorMeta; + +} // namespace one + namespace vm { class TensorStorage { @@ -91,23 +98,31 @@ class EagerBlobObject final : public user_op::Tensor, public: EagerBlobObject(const EagerBlobObject&) = delete; EagerBlobObject(EagerBlobObject&&) = delete; - EagerBlobObject(const std::shared_ptr& mem_case, const std::shared_ptr& shape, - const std::shared_ptr& stride, DataType data_type, - const std::shared_ptr& tensor_storage) - : EagerBlobObject(mem_case, shape, stride, data_type, tensor_storage, - intrusive::shared_ptr()) {} - EagerBlobObject(const std::shared_ptr& mem_case, const std::shared_ptr& shape, - const std::shared_ptr& stride, DataType data_type, - const std::shared_ptr& tensor_storage, + EagerBlobObject(const std::shared_ptr& mem_case, + const Symbol& local_tensor_meta, + const std::shared_ptr& mut_local_tensor_meta, + DataType data_type, const std::shared_ptr& tensor_storage) + : EagerBlobObject(mem_case, local_tensor_meta, mut_local_tensor_meta, data_type, + tensor_storage, intrusive::shared_ptr()) {} + EagerBlobObject(const std::shared_ptr& mem_case, + const Symbol& local_tensor_meta, + const std::shared_ptr& mut_local_tensor_meta, + DataType data_type, const std::shared_ptr& tensor_storage, const intrusive::shared_ptr& dep_object); ~EagerBlobObject() { tensor_storage_.reset(); } + const std::shared_ptr& mut_tensor_meta() { + return mut_local_tensor_meta_; + } + // Getters + const Symbol& tensor_meta() const { return local_tensor_meta_; } + // user_op::TensorDesc overrides - const Shape& shape() const override { return *shape_; } - Shape* mut_shape() override { return shape_.get(); } - const Stride& 
stride() const override { return *stride_; } - Stride* mut_stride() override { return stride_.get(); } + const Shape& shape() const override; + Shape* mut_shape() override; + const Stride& stride() const override; + Stride* mut_stride() override; DataType data_type() const override { return data_type_; } DataType* mut_data_type() override { return &data_type_; } bool is_dynamic() const override { return is_dynamic_; } @@ -115,8 +130,8 @@ class EagerBlobObject final : public user_op::Tensor, void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; } // user_op::Tensor overrides - ShapeView shape_view() const override { return *shape_; } - MutShapeView mut_shape_view() override { return *shape_; } + ShapeView shape_view() const override { return shape(); } + MutShapeView mut_shape_view() override { return *mut_shape(); } const MemoryCase& mem_case() const override { return *mem_case_; } const void* raw_dptr() const override { CHECK(inited_mem_ptr_for_allocation_compuation_pipelining_) @@ -164,10 +179,10 @@ class EagerBlobObject final : public user_op::Tensor, tensor_storage_->set_last_used_stream(last_used_stream); } - std::shared_ptr shape_ptr() const { return shape_; } - std::shared_ptr stride_ptr() const { return stride_; } + std::shared_ptr shape_ptr() const; + std::shared_ptr stride_ptr() const; - size_t ByteSizeOfBlobBody() const { return shape_->elem_cnt() * GetSizeOfDataType(data_type_); } + size_t ByteSizeOfBlobBody() const { return shape().elem_cnt() * GetSizeOfDataType(data_type_); } size_t AlignedByteSizeOfBlobBody() const { return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize); } @@ -176,8 +191,10 @@ class EagerBlobObject final : public user_op::Tensor, return RoundUp(ByteSizeOfBlobHeader(), kBlobHeaderAlignSize); } - const char* header_ptr() const { return reinterpret_cast(shape_->dim_vec().data()); } - char* mut_header_ptr() { return reinterpret_cast(shape_->dim_vec().data()); } + const char* header_ptr() const { return 
reinterpret_cast(shape().dim_vec().data()); } + char* mut_header_ptr() { + return reinterpret_cast(const_cast(shape().dim_vec().data())); + } void InitOrCheckMemPtrForAllocationComputationPipelining() { auto* ptr = tensor_storage_->blob_dptr(); @@ -201,8 +218,6 @@ class EagerBlobObject final : public user_op::Tensor, bool is_dynamic_; std::shared_ptr mem_case_; DataType data_type_; - std::shared_ptr shape_; - std::shared_ptr stride_; int64_t storage_offset_; std::shared_ptr tensor_storage_; // For allocation-computation pipeline, the value of mem_ptr_for_allocation_compuation_pipelining_ @@ -215,6 +230,8 @@ class EagerBlobObject final : public user_op::Tensor, // NOTE: Will be removed soon. Avoid to use it whenever possible. BlobDesc blob_desc_; std::unique_ptr blob_; + Symbol local_tensor_meta_; + std::shared_ptr mut_local_tensor_meta_; }; using EagerBlobObjectList = small_vector, kOpArgsReservedSize>; diff --git a/oneflow/core/eager/op_call_phy_instr_operand.cpp b/oneflow/core/eager/op_call_phy_instr_operand.cpp index 11272e7f486..404adaba251 100644 --- a/oneflow/core/eager/op_call_phy_instr_operand.cpp +++ b/oneflow/core/eager/op_call_phy_instr_operand.cpp @@ -47,7 +47,6 @@ OpCallPhyInstrOperand::OpCallPhyInstrOperand( } Maybe OpCallPhyInstrOperand::Init() { - OF_PROFILER_RANGE_GUARD("OpCallPhyInstrOperand::Init"); return mut_opkernel()->ChooseOpKernel(&call_ctx_, &user_opkernel_, &need_temp_storage_); } diff --git a/oneflow/core/framework/consistency_check.h b/oneflow/core/framework/consistency_check.h index 10934b5ba11..3729a63fb19 100644 --- a/oneflow/core/framework/consistency_check.h +++ b/oneflow/core/framework/consistency_check.h @@ -20,7 +20,7 @@ limitations under the License. 
#include "oneflow/core/common/symbol.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/framework/nd_sbp.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" namespace oneflow { diff --git a/oneflow/core/framework/framework.h b/oneflow/core/framework/framework.h index cb62c928131..c84a06b7b4a 100644 --- a/oneflow/core/framework/framework.h +++ b/oneflow/core/framework/framework.h @@ -26,7 +26,7 @@ limitations under the License. #include "oneflow/core/framework/infer_nd_sbp_fn_context.h" #include "oneflow/core/framework/user_op_hob.h" -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/multi_thread.h" diff --git a/oneflow/core/framework/global_tensor_infer_cache.h b/oneflow/core/framework/global_tensor_infer_cache.h index f2104100009..3a5a05ff987 100644 --- a/oneflow/core/framework/global_tensor_infer_cache.h +++ b/oneflow/core/framework/global_tensor_infer_cache.h @@ -22,7 +22,7 @@ limitations under the License. #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/job/nd_sbp_infer_hint.h" diff --git a/oneflow/core/framework/infer_util.h b/oneflow/core/framework/infer_util.h index 5b32ea31844..1fcb07a7590 100644 --- a/oneflow/core/framework/infer_util.h +++ b/oneflow/core/framework/infer_util.h @@ -18,7 +18,7 @@ limitations under the License. 
#include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/user_op_conf.h" -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/job/sbp_parallel.h" diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 8bd8759ecdc..248a8589ceb 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -86,25 +86,32 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in const auto& output_tensor_metas = result->output_tensor_metas(); vm::EagerBlobObjectList output_eager_blob_objects(outputs->size()); + const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream())); + for (int i = 0; i < outputs->size(); i++) { if (!outputs->at(i)) { // NOTE: if op support stride(non-contiguous input), then output tensor's stride // should be inferred in InferLogicalTensorDesc. // otherwise, it will be set here(according to shape). - // Note: symbol.shared_from_symbol() cannot be used here because set_stride happens in the - // next step. 
- std::shared_ptr tensor_impl = std::make_shared( - std::make_shared(*output_tensor_metas.at(i)), false, false); - if (!JUST(user_op_expr.SupportNonContiguous())) { - std::shared_ptr stride(new Stride(*tensor_impl->shape())); - tensor_impl->mut_tensor_meta()->set_stride(stride); + std::shared_ptr mut_tensor_meta; + { + if (kernel->output_is_mut2_type(i)) { + mut_tensor_meta = std::make_shared( + std::make_shared(output_tensor_metas.at(i)->shape()), + std::make_shared(output_tensor_metas.at(i)->stride()), + output_tensor_metas.at(i)->dtype(), output_tensor_metas.at(i)->device(), + output_tensor_metas.at(i)->storage_offset()); + } } + std::shared_ptr tensor_impl = + std::make_shared(false, false); const auto& dep_object = NewLocalDepObject(); - JUST(tensor_impl->InitEagerBlobObject(dep_object)); + JUST( + tensor_impl->InitEagerBlobObject(output_tensor_metas.at(i), mut_tensor_meta, dep_object)); output_eager_blob_objects.at(i) = JUST(tensor_impl->eager_blob_object()); (*outputs)[i] = std::make_shared(tensor_impl); } else { - auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); + const auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); // output i is inplaced. // check TensorMeta of infer result and TensorMeta of output i. 
CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas.at(i)->shape()); @@ -118,9 +125,6 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } } - const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream())); - - OF_PROFILER_RANGE_PUSH("PhysicalRun"); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->Call(kernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects), ctx, result->stream()); @@ -133,8 +137,18 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in tensor_impl, btb, [](uint64_t) {}, "const"); })); JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); + const auto& mut_tensor_meta = const_cast(tensor_impl)->mut_tensor_meta(); + Symbol new_tensor_meta = SymbolOf(LocalTensorMeta( + std::make_shared(mut_tensor_meta->shape()), + std::make_shared(mut_tensor_meta->stride()), mut_tensor_meta->dtype(), + mut_tensor_meta->device(), mut_tensor_meta->storage_offset())); + std::shared_ptr final_tensor_impl = + std::make_shared(JUST(tensor_impl->tensor_storage()), false, false); + JUST(final_tensor_impl->InitEagerBlobObject( + new_tensor_meta, + JUST(JUST(outputs->at(index)->eager_blob_object())->compute_local_dep_object()))); + JUST(JUST(outputs->at(index)->AsLocalTensor())->set_impl(final_tensor_impl)); } - OF_PROFILER_RANGE_POP(); return Maybe::Ok(); } diff --git a/oneflow/core/framework/placement_sbp_util.cpp b/oneflow/core/framework/placement_sbp_util.cpp index 5bbae902e29..de3e01031c0 100644 --- a/oneflow/core/framework/placement_sbp_util.cpp +++ b/oneflow/core/framework/placement_sbp_util.cpp @@ -17,7 +17,7 @@ limitations under the License. 
#include #include "oneflow/core/framework/placement_sbp_util.h" #include "oneflow/core/framework/placed_nd_sbp.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/util.h" diff --git a/oneflow/core/framework/placement_sbp_util_test.cpp b/oneflow/core/framework/placement_sbp_util_test.cpp index 4bb1fbd876d..e02302063d9 100644 --- a/oneflow/core/framework/placement_sbp_util_test.cpp +++ b/oneflow/core/framework/placement_sbp_util_test.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "oneflow/core/framework/placement_sbp_util.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" diff --git a/oneflow/core/framework/sync_symbol_global_tensor_meta.cpp b/oneflow/core/framework/sync_symbol_global_tensor_meta.cpp index 3eaeabf08ba..cea91be4d1c 100644 --- a/oneflow/core/framework/sync_symbol_global_tensor_meta.cpp +++ b/oneflow/core/framework/sync_symbol_global_tensor_meta.cpp @@ -17,7 +17,7 @@ limitations under the License. 
#include "oneflow/core/framework/sync_symbol_parallel_desc.h" #include "oneflow/core/framework/sync_symbol_nd_sbp.h" #include "oneflow/core/framework/rank_group_rpc_util.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/framework/synced_symbol_map.h" #include "oneflow/core/common/flat_shape.h" diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 34ad15e31be..f3949c27204 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -65,12 +65,14 @@ std::shared_ptr Parameter::pin_memory() const { const Symbol& device, bool is_lazy, bool requires_grad, bool is_leaf) { const auto& tensor_meta = - std::make_shared(std::make_shared(*shape), dtype, device); + SymbolOf(LocalTensorMeta(std::make_shared(*shape), dtype, device)); if (is_lazy) { const auto& impl = std::make_shared(tensor_meta, requires_grad, is_leaf); return std::make_shared(impl); } else { - const auto& impl = std::make_shared(tensor_meta, requires_grad, is_leaf); + const auto& impl = std::make_shared(requires_grad, is_leaf); + const auto& dep_object = NewLocalDepObject(); + impl->InitEagerBlobObject(tensor_meta, dep_object); return std::make_shared(impl); } } diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index c70adbf07a0..b21bbaf8332 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -64,6 +64,7 @@ class Tensor : public std::enable_shared_from_this { virtual const TensorMeta& tensor_meta() const = 0; virtual Maybe data() = 0; virtual std::shared_ptr pin_memory() const = 0; + virtual Maybe> local_tensor_meta() const { OF_UNIMPLEMENTED(); } virtual Maybe> global_tensor_meta() const { OF_UNIMPLEMENTED(); } // Getters valid only for EagerLocalTensor @@ -164,6 +165,9 @@ class StaticZerosTensor final : public Tensor { std::shared_ptr pin_memory() const override { return std::const_pointer_cast(shared_from_this()); 
} + Maybe> local_tensor_meta() const override { + RETURN_ERROR_WITH_BUG_PROMPT(); + } Maybe> global_tensor_meta() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } @@ -315,6 +319,9 @@ class ProxyTensor : public TensorIf { virtual bool is_lazy() const override { return tensor_->is_lazy(); } virtual bool is_eager() const override { return tensor_->is_eager(); } virtual const TensorMeta& tensor_meta() const override { return tensor_->tensor_meta(); } + virtual Maybe> local_tensor_meta() const override { + return tensor_->local_tensor_meta(); + } virtual Maybe> global_tensor_meta() const override { return tensor_->global_tensor_meta(); } @@ -496,6 +503,8 @@ class LocalTensor final : public TensorIf { bool is_contiguous() const override { return impl_->is_contiguous(); } Maybe is_pinned() const override { return impl_->is_pinned(); }; + Maybe> local_tensor_meta() const override { return impl_->tensor_meta(); } + // Setters for autograd Maybe set_acc_grad(const std::shared_ptr& grad) override { return impl_->set_acc_grad(grad); @@ -530,9 +539,16 @@ class LocalTensor final : public TensorIf { Maybe mut_eager_local_tensor_impl() override { return impl_->mut_eager_local_tensor_impl(); } - user_op::TensorDesc* mut_tensor_meta() override { return impl_->mut_tensor_meta(); } + user_op::TensorDesc* mut_tensor_meta() override { + return std::const_pointer_cast(impl_->mut_tensor_meta()).get(); + } Maybe set_data(const std::shared_ptr& other) override; + Maybe set_impl(std::shared_ptr impl) { + impl_ = impl; + return Maybe::Ok(); + } + Maybe RegisterStorageDeleteHook(const std::function& hook) override { return impl_->RegisterStorageDeleteHook(hook); } diff --git a/oneflow/core/framework/tensor_impl.cpp b/oneflow/core/framework/tensor_impl.cpp index f4d6ea92859..9999604dfeb 100644 --- a/oneflow/core/framework/tensor_impl.cpp +++ b/oneflow/core/framework/tensor_impl.cpp @@ -16,7 +16,7 @@ limitations under the License. 
#include #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/common/stream_role.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor_impl.h" @@ -68,20 +68,14 @@ Maybe LazyLocalTensorImpl::detach() const { return std::shared_ptr(detached_impl); } -EagerLocalTensorImpl::EagerLocalTensorImpl() - : LocalTensorImpl(std::make_shared(), false, false) {} +EagerLocalTensorImpl::EagerLocalTensorImpl() : LocalTensorImpl(false, false) {} -EagerLocalTensorImpl::EagerLocalTensorImpl( - const std::shared_ptr& tensor_meta, bool requires_grad, bool is_leaf) - : LocalTensorImpl(tensor_meta, requires_grad, is_leaf) {} +EagerLocalTensorImpl::EagerLocalTensorImpl(const std::shared_ptr& tensor_storage, + bool requires_grad, bool is_leaf) + : LocalTensorImpl(requires_grad, is_leaf), tensor_storage_(tensor_storage) {} EagerLocalTensorImpl::~EagerLocalTensorImpl() {} -EagerLocalTensorImpl::EagerLocalTensorImpl( - const std::shared_ptr& tensor_meta, - const std::shared_ptr& tensor_storage, bool requires_grad, bool is_leaf) - : LocalTensorImpl(tensor_meta, requires_grad, is_leaf), tensor_storage_(tensor_storage) {} - Maybe EagerLocalTensorImpl::UpdateTensorStorage() { const auto& eager_blob_object = eager_blob_object_; tensor_storage_ = std::make_shared(eager_blob_object->tensor_storage()); @@ -97,25 +91,34 @@ Maybe EagerLocalTensorImpl::UpdateTensorStorage() { return Maybe::Ok(); } +const std::shared_ptr& EagerLocalTensorImpl::mut_tensor_meta() { + return eager_blob_object_->mut_tensor_meta(); +} +// Getters +const Symbol& EagerLocalTensorImpl::tensor_meta() const { + return eager_blob_object_->tensor_meta(); +} + Maybe EagerLocalTensorImpl::compute_local_dep_object() const { return JUST(eager_blob_object())->compute_local_dep_object(); } Maybe 
EagerLocalTensorImpl::InitEagerBlobObject( + const Symbol& local_tensor_meta, + const std::shared_ptr& mut_local_tensor_meta, const intrusive::shared_ptr& dep_object) { - CHECK_OR_RETURN(static_cast(device())); - const auto& mem_case = device()->mem_case(); - const auto& mut_shape = std::const_pointer_cast(tensor_meta()->shape_ptr()); - const auto& mut_stride = std::const_pointer_cast(tensor_meta()->stride_ptr()); + CHECK_OR_RETURN(static_cast(local_tensor_meta->device())); + const auto& mem_case = local_tensor_meta->device()->mem_case(); if (tensor_storage_) { auto tensor_storage = tensor_storage_->storage(); - eager_blob_object_ = std::make_shared(mem_case, mut_shape, mut_stride, - dtype(), tensor_storage, dep_object); + eager_blob_object_ = std::make_shared( + mem_case, local_tensor_meta, mut_local_tensor_meta, local_tensor_meta->dtype(), + tensor_storage, dep_object); } else { - const auto& eager_blob_object = - std::make_shared(mem_case, mut_shape, mut_stride, dtype(), - std::make_shared(), dep_object); + const auto& eager_blob_object = std::make_shared( + mem_case, local_tensor_meta, mut_local_tensor_meta, local_tensor_meta->dtype(), + std::make_shared(), dep_object); JUST(set_eager_blob_object(eager_blob_object)); } return Maybe::Ok(); @@ -149,8 +152,7 @@ std::shared_ptr EagerLocalTensorImpl::stride() const { } Maybe EagerLocalTensorImpl::detach() const { - auto detached_impl = - std::make_shared(tensor_meta_, tensor_storage_, false, true); + auto detached_impl = std::make_shared(tensor_storage_, false, true); detached_impl->eager_blob_object_ = eager_blob_object_; return std::shared_ptr(detached_impl); } @@ -211,11 +213,10 @@ Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, // empty op. 
if (parallel_id.has_value() && shape->elem_cnt() != 0) { const auto& cur_rank_phy_tensor_meta = - std::make_shared(cur_rank_phy_shape, dtype, device); - auto cur_rank_phy_tensor_impl = - std::make_shared(cur_rank_phy_tensor_meta, requires_grad, is_leaf); + SymbolOf(LocalTensorMeta(cur_rank_phy_shape, dtype, device)); + auto cur_rank_phy_tensor_impl = std::make_shared(requires_grad, is_leaf); const auto& dep_object = NewLocalDepObject(); - JUST(cur_rank_phy_tensor_impl->InitEagerBlobObject(dep_object)); + JUST(cur_rank_phy_tensor_impl->InitEagerBlobObject(cur_rank_phy_tensor_meta, dep_object)); cur_rank_phy_tensor = std::make_shared(cur_rank_phy_tensor_impl); } else { const auto& dtype_symbol = JUST(DType::Get(dtype)); diff --git a/oneflow/core/framework/tensor_impl.h b/oneflow/core/framework/tensor_impl.h index 1e4ad7dba5d..a77a308db22 100644 --- a/oneflow/core/framework/tensor_impl.h +++ b/oneflow/core/framework/tensor_impl.h @@ -21,8 +21,8 @@ limitations under the License. #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/tensor_storage.h" -#include "oneflow/core/framework/tensor_desc.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_desc.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/framework/transport_token.h" #include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/common/symbol.h" @@ -105,14 +105,16 @@ class LocalTensorImpl : public TensorImpl { virtual ~LocalTensorImpl() = default; // Getters - DataType dtype() const override { return tensor_meta_->dtype(); } - const Symbol& device() const { return tensor_meta_->device(); } - const std::shared_ptr& tensor_meta() const { return tensor_meta_; } - bool is_contiguous() const override { return tensor_meta_->is_contiguous(); } + DataType dtype() const override { return tensor_meta()->dtype(); } + const Symbol& device() const { return tensor_meta()->device(); } + bool 
is_contiguous() const override { return tensor_meta()->is_contiguous(); } + virtual const Symbol& tensor_meta() const = 0; // Setters - LocalTensorMeta* mut_tensor_meta() { return const_cast(tensor_meta_.get()); } - Maybe*> mut_device() { return mut_tensor_meta()->mut_device(); } + virtual const std::shared_ptr& mut_tensor_meta() = 0; + Maybe*> mut_device() { + return std::const_pointer_cast(mut_tensor_meta())->mut_device(); + } virtual Maybe mut_eager_local_tensor_impl() { RETURN_ERROR_WITH_BUG_PROMPT(); } @@ -120,11 +122,7 @@ class LocalTensorImpl : public TensorImpl { virtual Maybe detach() const { RETURN_ERROR_WITH_BUG_PROMPT(); } protected: - LocalTensorImpl(const std::shared_ptr& tensor_meta, bool requires_grad, - bool is_leaf) - : TensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta) {} - - std::shared_ptr tensor_meta_; + LocalTensorImpl(bool requires_grad, bool is_leaf) : TensorImpl(requires_grad, is_leaf) {} }; class LocalTensor; @@ -186,12 +184,12 @@ class GlobalTensorImpl : public TensorImpl { class LazyLocalTensorImpl final : public LocalTensorImpl { public: OF_DISALLOW_COPY_AND_MOVE(LazyLocalTensorImpl); - LazyLocalTensorImpl(const std::shared_ptr& tensor_meta, bool requires_grad, - bool is_leaf) - : LocalTensorImpl(tensor_meta, requires_grad, is_leaf) {} + LazyLocalTensorImpl(const Symbol& tensor_meta, bool requires_grad, bool is_leaf) + : LocalTensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta) {} ~LazyLocalTensorImpl() override = default; // Getters + const Symbol& tensor_meta() const override { return tensor_meta_; } std::shared_ptr shape() const override { return tensor_meta()->shape_ptr(); } std::shared_ptr stride() const override { return tensor_meta()->stride_ptr(); } bool is_lazy() const override { return true; } @@ -202,6 +200,10 @@ class LazyLocalTensorImpl final : public LocalTensorImpl { } Maybe is_pinned() const override { return false; } + const std::shared_ptr& mut_tensor_meta() override { + 
PRINT_BUG_PROMPT_AND_ABORT(); + } + // Getters valid only for EagerLocalTensorImpl Maybe eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe compute_local_dep_object() const override { @@ -210,25 +212,30 @@ class LazyLocalTensorImpl final : public LocalTensorImpl { Maybe tensor_storage() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe detach() const override; + + private: + Symbol tensor_meta_; }; class EagerLocalTensorImpl final : public LocalTensorImpl { public: OF_DISALLOW_COPY_AND_MOVE(EagerLocalTensorImpl); EagerLocalTensorImpl(); - EagerLocalTensorImpl(const std::shared_ptr& tensor_meta, - bool requires_grad, bool is_leaf); - EagerLocalTensorImpl(const std::shared_ptr& tensor_meta, - const std::shared_ptr& tensor_storage, bool requires_grad, + EagerLocalTensorImpl(const std::shared_ptr& tensor_storage, bool requires_grad, bool is_leaf); + + EagerLocalTensorImpl(bool requires_grad, bool is_leaf) + : EagerLocalTensorImpl(std::shared_ptr(), requires_grad, is_leaf) {} ~EagerLocalTensorImpl() override; + const std::shared_ptr& mut_tensor_meta() override; // Getters + const Symbol& tensor_meta() const override; std::shared_ptr shape() const override; std::shared_ptr stride() const override; Maybe detach() const override; bool is_lazy() const override { return false; } - bool is_contiguous() const override { return tensor_meta_->is_contiguous(); } + bool is_contiguous() const override { return tensor_meta()->is_contiguous(); } Maybe is_pinned() const override; // Getters valid only for EagerLocalTensorImpl @@ -242,12 +249,21 @@ class EagerLocalTensorImpl final : public LocalTensorImpl { return tensor_storage_; } Maybe has_eager_blob_object() const override { return eager_blob_object_.get(); } - Maybe storage_offset() const override { return tensor_meta_->storage_offset(); } - + Maybe storage_offset() const override { return 
tensor_meta()->storage_offset(); } // Setters TensorStorage* mut_tensor_storage() { return tensor_storage_.get(); } - Maybe InitEagerBlobObject(const intrusive::shared_ptr& dep_object); + Maybe InitEagerBlobObject( + const Symbol& local_tensor_meta, + const std::shared_ptr& mut_local_tensor_meta, + const intrusive::shared_ptr& dep_object); + Maybe InitEagerBlobObject(const Symbol& local_tensor_meta, + const intrusive::shared_ptr& dep_object) { + JUST(InitEagerBlobObject(local_tensor_meta, std::shared_ptr(), + dep_object)); + return Maybe::Ok(); + } + Maybe mut_eager_local_tensor_impl() override { return this; } Maybe RegisterStorageDeleteHook(const std::function& hook) override; diff --git a/oneflow/core/framework/tensor_methods.cpp b/oneflow/core/framework/tensor_methods.cpp index 8d3ebc842ad..cfc4ddc287c 100644 --- a/oneflow/core/framework/tensor_methods.cpp +++ b/oneflow/core/framework/tensor_methods.cpp @@ -64,18 +64,19 @@ Maybe BasicView(const std::shared_ptr& input, const Shape& targe const Stride& target_stride, int64_t storage_offset) { // TODO(): Check shape compatible. 
auto device = JUST(input->device()); - auto tensor_meta = std::make_shared( - std::make_shared(target_shape), std::make_shared(target_stride), - input->dtype()->data_type(), device, storage_offset); + auto tensor_meta = SymbolOf(LocalTensorMeta(std::make_shared(target_shape), + std::make_shared(target_stride), + input->dtype()->data_type(), device, storage_offset)); CHECK_OR_RETURN(JUST(input->has_eager_blob_object())); // new output tensor const auto& blob_object = JUST(input->eager_blob_object()); bool requires_grad = (autograd::GradMode::is_enabled() && input->requires_grad()); - auto tensor_impl = std::make_shared( - tensor_meta, JUST(input->tensor_storage()), requires_grad, - /*is_leaf=*/!requires_grad); - JUST(tensor_impl->InitEagerBlobObject(JUST(blob_object->compute_local_dep_object()))); + auto tensor_impl = + std::make_shared(JUST(input->tensor_storage()), requires_grad, + /*is_leaf=*/!requires_grad); + JUST( + tensor_impl->InitEagerBlobObject(tensor_meta, JUST(blob_object->compute_local_dep_object()))); auto view_tensor = std::make_shared(tensor_impl); diff --git a/oneflow/core/framework/user_op_conf.h b/oneflow/core/framework/user_op_conf.h index 706ca3efcf9..69e62503ef5 100644 --- a/oneflow/core/framework/user_op_conf.h +++ b/oneflow/core/framework/user_op_conf.h @@ -18,7 +18,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/common/maybe.h" -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/user_op_def.pb.h" #include "oneflow/core/framework/user_op_attr.pb.h" #include "oneflow/core/framework/user_op_conf.pb.h" diff --git a/oneflow/core/framework/user_op_registry_manager.cpp b/oneflow/core/framework/user_op_registry_manager.cpp index f573ee9442c..0760edd2483 100644 --- a/oneflow/core/framework/user_op_registry_manager.cpp +++ b/oneflow/core/framework/user_op_registry_manager.cpp @@ -17,7 +17,7 @@ limitations under the License. 
#include "oneflow/core/common/util.h" #include "oneflow/core/framework/infer_util.h" -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/kernel/kernel.pb.h" #include "oneflow/core/operator/operator.h" diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 3f285aaf5fb..c1df6d71e72 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -1215,11 +1215,21 @@ class InplaceToContiguousFunctor { << "Both ref and value must be local tensor."; std::shared_ptr stride(new Stride(*input->shape())); // update stride - JUST(input->mut_eager_local_tensor_impl())->mut_tensor_meta()->set_stride(stride); const auto& blob_object = JUST(input->eager_blob_object()); - // update eager_blob_object - JUST(JUST(input->mut_eager_local_tensor_impl()) - ->InitEagerBlobObject(JUST(blob_object->compute_local_dep_object()))); + Symbol old_tensor_meta = JUST(input->local_tensor_meta()); + + Symbol new_tensor_meta = SymbolOf(LocalTensorMeta( + std::make_shared(old_tensor_meta->shape()), stride, old_tensor_meta->dtype(), + old_tensor_meta->device(), old_tensor_meta->storage_offset())); + + std::shared_ptr final_tensor_impl = + std::make_shared(JUST(input->tensor_storage()), + input->requires_grad(), input->is_leaf()); + final_tensor_impl->set_retain_grad(input->retain_grad()); + final_tensor_impl->InitEagerBlobObject(new_tensor_meta, + JUST(blob_object->compute_local_dep_object())); + JUST(JUST(input->AsLocalTensor())->set_impl(final_tensor_impl)); + // assign contiguous tensor data JUST(OpInterpUtil::Dispatch(*assign_op_, {input, contiguous_tensor})); return input; diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index e7e9d8c2d2f..c334877fe83 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -15,7 +15,7 @@ limitations under the License. 
*/ #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/sbp_context.h" -#include "oneflow/core/framework/tensor_desc.h" +#include "oneflow/core/common/tensor_desc.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/operator/user_op.h" #include "oneflow/core/framework/infer_output_blob_time_shape_fn_context.h" diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 15fdec9b24a..e6953b84be4 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -174,7 +174,7 @@ class UserOpInferContextHelper final { const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *Shape4ArgNameAndIndex(call_ctx, arg_name, index); + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->shape(); } Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { @@ -186,7 +186,7 @@ class UserOpInferContextHelper final { } const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *Stride4ArgNameAndIndex(call_ctx, arg_name, index); + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->stride(); } Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { @@ -665,7 +665,8 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr std::vector* input_tuple_indexes4const_ibns, std::vector* input_tuple_indexes4mut_ibns, std::vector* output_tuple_indexes4mut_obns, - std::vector* output_tuple_indexes4mut2_obns) { + std::vector* output_tuple_indexes4mut2_obns, + HashMap* output_tuple_indexe2is_mut2_type) { const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name()); CHECK_NOTNULL_OR_RETURN(op_reg_val); @@ -718,8 +719,10 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr 
const std::string obn = GenRepeatedBn(pair.first, pair.second); if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) { output_tuple_indexes4mut_obns->emplace_back(i); + output_tuple_indexe2is_mut2_type->emplace(i, false); } else { output_tuple_indexes4mut2_obns->emplace_back(i); + output_tuple_indexe2is_mut2_type->emplace(i, true); } } return Maybe::Ok(); @@ -766,7 +769,7 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr op_conf, input_arg_tuple->indexed_arg_name_and_index(), output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_, &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_, - &opkernel->output_tuple_indexes4mut2_obns_)); + &opkernel->output_tuple_indexes4mut2_obns_, &opkernel->output_tuple_indexe2is_mut2_type_)); return opkernel; } diff --git a/oneflow/user/kernels/stateful_opkernel.h b/oneflow/user/kernels/stateful_opkernel.h index cfcb0477de2..608d9bbc378 100644 --- a/oneflow/user/kernels/stateful_opkernel.h +++ b/oneflow/user/kernels/stateful_opkernel.h @@ -17,7 +17,7 @@ limitations under the License. 
#define ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_ #include "oneflow/core/eager/eager_blob_object.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/framework/stream.h" @@ -71,6 +71,10 @@ class StatefulOpKernel final { return output_tuple_indexes4mut2_obns_; } + bool output_is_mut2_type(int64_t index) const { + return output_tuple_indexe2is_mut2_type_.at(index); + } + const AttrMap& base_attrs() const { return base_attrs_; } size_t InferTmpSize(eager::CallContext* call_ctx, const user_op::OpKernel* user_opkernel) const; @@ -126,6 +130,7 @@ class StatefulOpKernel final { std::vector input_tuple_indexes4mut_ibns_; std::vector output_tuple_indexes4mut_obns_; std::vector output_tuple_indexes4mut2_obns_; + HashMap output_tuple_indexe2is_mut2_type_; }; } // namespace one From ee9a52592040ff567116ca262f5f0393e93ded13 Mon Sep 17 00:00:00 2001 From: clackhan Date: Fri, 15 Jul 2022 12:37:38 +0800 Subject: [PATCH 17/67] check shape in critical_sectio --- .../core/eager/critical_section_phy_instr_operand.cpp | 5 ++--- oneflow/core/register/blob.h | 9 ++++++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/oneflow/core/eager/critical_section_phy_instr_operand.cpp b/oneflow/core/eager/critical_section_phy_instr_operand.cpp index 5e5d2637299..e0f3e68887f 100644 --- a/oneflow/core/eager/critical_section_phy_instr_operand.cpp +++ b/oneflow/core/eager/critical_section_phy_instr_operand.cpp @@ -70,8 +70,7 @@ void InputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_bl { size_t header_size = of_blob->mut_blob()->blob_desc().ByteSizeOfBlobHeader(); CHECK_EQ(header_size, eager_blob_object->shape().NumAxes() * sizeof(int64_t)); - std::memcpy(of_blob->mut_blob()->mut_header_ptr(), eager_blob_object->mut_header_ptr(), - header_size); + CHECK_EQ(of_blob->blob().static_shape(), 
eager_blob_object->shape()); } const auto& end_event_record = op_name2end_event_record_->at(op_name); if (eager_blob_object->dptr() == nullptr) { @@ -93,7 +92,7 @@ void OutputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_b CHECK(interfaces_valid().at(i)); OfBlob* of_blob = reinterpret_cast(of_blob_ptr); auto& eager_blob_object = eager_blob_objects_->at(i); - of_blob->blob().shape_view().ToShape(eager_blob_object->mut_shape()); + CHECK_EQ(of_blob->blob().static_shape(), eager_blob_object->shape()); const auto& end_event_record = op_name2end_event_record_->at(op_name); if (eager_blob_object->dptr() == nullptr) { end_event_record->Init(std::make_shared()); diff --git a/oneflow/core/register/blob.h b/oneflow/core/register/blob.h index bea1635a938..56a80b8cbfa 100644 --- a/oneflow/core/register/blob.h +++ b/oneflow/core/register/blob.h @@ -56,7 +56,12 @@ class Blob final { DataType data_type() const { return blob_desc_->data_type(); } const char* header_ptr() const { return header_ptr_; } - char* mut_header_ptr() { return header_ptr_; } + [[deprecated( + "\"mut_header_ptr\" will be removed in Bolb. Please avoid to use this method whenever " + "possible. 
Almost all methods of `mut_header_ptr` are also in `Blob`.")]] char* + mut_header_ptr() { + return header_ptr_; + } char* mut_contiguous_header_ptr(); const BlobDesc& blob_desc() const { return *blob_desc_; } const BlobDesc* blob_desc_ptr() const { return blob_desc_; } @@ -91,6 +96,7 @@ class Blob final { CheckDataType(data_type()); return static_cast(dptr_); } + // shape const Shape& static_shape() const { return blob_desc_->shape(); } const ShapeView& shape_view() const { return *shape_view_; } @@ -100,6 +106,7 @@ class Blob final { return mut_shape_view_.get(); } MutShapeView* ForceMutShapeView() { return mut_shape_view_.get(); } + // stride const Stride& stride() const { return blob_desc_->stride(); } From d2e16efb45bfe604d9bdf17aae7a5f171fe62aa7 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 11:06:22 +0800 Subject: [PATCH 18/67] add kMaxNumDims --- oneflow/core/common/constant.h | 1 + 1 file changed, 1 insertion(+) diff --git a/oneflow/core/common/constant.h b/oneflow/core/common/constant.h index 3f8b331bdb4..7760e161128 100644 --- a/oneflow/core/common/constant.h +++ b/oneflow/core/common/constant.h @@ -24,6 +24,7 @@ static const int64_t kInvalidSessionId = -1; static const std::string kNoPassTag = ""; static const std::string kMainOp = "main_op"; static const int64_t kMaxSplitAxis = 6; +constexpr size_t kMaxNumDims = 8; static const std::string kAsymmetricCodeErrorMsg = "Maybe executing different code in different ranks, please check if the code is branched and " "operates on the global tensor."; From a395638c95ad1007dd7334062f136b1dde03aa5c Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 11:11:04 +0800 Subject: [PATCH 19/67] fix error include --- oneflow/core/framework/local_tensor_infer_cache.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h index 055a5edb4a1..58d745bd987 100644 --- 
a/oneflow/core/framework/local_tensor_infer_cache.h +++ b/oneflow/core/framework/local_tensor_infer_cache.h @@ -23,7 +23,7 @@ limitations under the License. #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" -#include "oneflow/core/common/tensor_meta.h" +#include "oneflow/core/framework/tensor_meta.h" namespace oneflow { From ff369afe054c7b12e3954dc8aa8e92bd2481a851 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 11:25:16 +0800 Subject: [PATCH 20/67] fix split Symbol LocalTensorMeta error --- .../core/framework/local_tensor_infer_cache.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 013958d1f30..eabbd06ee36 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -58,7 +58,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr public: UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const LocalTensorMetaInferArgs* infer_args, - std::vector* output_tensor_metas) + std::vector* output_tensor_metas) : user_op_expr_(user_op_expr), composed_attrs_(infer_args->attrs(), user_op_expr->base_attrs()), in_tensor_devices_(user_op_expr_->input_size()), @@ -111,7 +111,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args, - std::vector* output_tensor_metas) { + std::vector* output_tensor_metas) { if (!user_op_expr.device_and_stream_infer_fn()) { Symbol device = infer_args.input_local_tensor_metas().at(0)->device(); return GetDefaultStreamByDevice(device); @@ -150,7 +150,10 @@ Maybe LocalTensorMetaInferArgs::Init(const AttrMap& attrs, Symbol Maybe 
LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) { for (int i = 0; i < input_tensors.size(); ++i) { - input_local_tensor_metas_.at(i) = JUST(input_tensors.at(i)->local_tensor_meta()); + LocalTensorMeta* local_tensor_meta = + dynamic_cast(input_tensors.at(i)->mut_tensor_meta()); + CHECK_NOTNULL_OR_RETURN(local_tensor_meta); + input_local_tensor_metas_.at(i) = SymbolOf(*local_tensor_meta); } return Maybe::Ok(); } @@ -163,7 +166,7 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl auto result = std::make_unique(user_op_expr.output_size()); - std::vector output_mut_metas(user_op_expr.output_size()); + std::vector output_mut_metas(user_op_expr.output_size()); // Infer devices Symbol stream; if (!user_op_expr.has_device_and_stream_infer_fn()) { @@ -192,10 +195,7 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl std::shared_ptr stride(new Stride(output_mut_metas.at(i).shape())); output_mut_metas.at(i).set_stride(stride); } - output_metas->at(i) = SymbolOf( - LocalTensorMeta(output_mut_metas.at(i).shape_ptr(), output_mut_metas.at(i).stride_ptr(), - output_mut_metas.at(i).data_type(), output_mut_metas.at(i).device(), - output_mut_metas.at(i).storage_offset())); + output_metas->at(i) = SymbolOf(output_mut_metas.at(i)); } return std::shared_ptr(std::move(result)); } From 83c68b76661c6eed2f190ae3b2845e53f4ff7012 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 11:59:13 +0800 Subject: [PATCH 21/67] fix split cache and symbolic local tensor meta error --- .../core/framework/local_tensor_infer_cache.cpp | 16 ++++++++-------- .../core/framework/local_tensor_infer_cache.h | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index eabbd06ee36..013958d1f30 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ 
b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -58,7 +58,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr public: UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const LocalTensorMetaInferArgs* infer_args, - std::vector* output_tensor_metas) + std::vector* output_tensor_metas) : user_op_expr_(user_op_expr), composed_attrs_(infer_args->attrs(), user_op_expr->base_attrs()), in_tensor_devices_(user_op_expr_->input_size()), @@ -111,7 +111,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args, - std::vector* output_tensor_metas) { + std::vector* output_tensor_metas) { if (!user_op_expr.device_and_stream_infer_fn()) { Symbol device = infer_args.input_local_tensor_metas().at(0)->device(); return GetDefaultStreamByDevice(device); @@ -150,10 +150,7 @@ Maybe LocalTensorMetaInferArgs::Init(const AttrMap& attrs, Symbol Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) { for (int i = 0; i < input_tensors.size(); ++i) { - LocalTensorMeta* local_tensor_meta = - dynamic_cast(input_tensors.at(i)->mut_tensor_meta()); - CHECK_NOTNULL_OR_RETURN(local_tensor_meta); - input_local_tensor_metas_.at(i) = SymbolOf(*local_tensor_meta); + input_local_tensor_metas_.at(i) = JUST(input_tensors.at(i)->local_tensor_meta()); } return Maybe::Ok(); } @@ -166,7 +163,7 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl auto result = std::make_unique(user_op_expr.output_size()); - std::vector output_mut_metas(user_op_expr.output_size()); + std::vector output_mut_metas(user_op_expr.output_size()); // Infer devices Symbol stream; if (!user_op_expr.has_device_and_stream_infer_fn()) { @@ -195,7 +192,10 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl std::shared_ptr stride(new 
Stride(output_mut_metas.at(i).shape())); output_mut_metas.at(i).set_stride(stride); } - output_metas->at(i) = SymbolOf(output_mut_metas.at(i)); + output_metas->at(i) = SymbolOf( + LocalTensorMeta(output_mut_metas.at(i).shape_ptr(), output_mut_metas.at(i).stride_ptr(), + output_mut_metas.at(i).data_type(), output_mut_metas.at(i).device(), + output_mut_metas.at(i).storage_offset())); } return std::shared_ptr(std::move(result)); } diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h index 58d745bd987..055a5edb4a1 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.h +++ b/oneflow/core/framework/local_tensor_infer_cache.h @@ -23,7 +23,7 @@ limitations under the License. #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" -#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/common/tensor_meta.h" namespace oneflow { From db6cb63a06a2bd733a94f1c6f05036a5739ff57d Mon Sep 17 00:00:00 2001 From: lixinqi Date: Mon, 18 Jul 2022 13:08:45 +0800 Subject: [PATCH 22/67] refactor SoftSync --- .../core/framework/instructions_builder.cpp | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 00c95a9f6f7..8115846bb06 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -451,10 +451,10 @@ Maybe ForEachEagerBlobObjectsNeedingSoftSync( if (unlikely(!opt_last_used_stream.has_value())) { continue; } const auto& last_used_stream = JUST(opt_last_used_stream); if (last_used_stream != stream) { - const auto& ForEachEagerBlobObject = [&](const auto& DoEachEagerBlobObject) -> Maybe { - return DoEachEagerBlobObject(eager_blob_object); - }; - JUST(DoEach(last_used_stream, ForEachEagerBlobObject)); + small_vector, kOpArgsReservedSize> 
dep_objects{ + intrusive::shared_ptr( + JUST(eager_blob_object->compute_local_dep_object()))}; + JUST(DoEach(last_used_stream, std::move(dep_objects))); } } } else { @@ -466,17 +466,15 @@ Maybe ForEachEagerBlobObjectsNeedingSoftSync( if (last_used_stream != stream) { SmallSetInsert(&last_used_streams, last_used_stream); } } for (const auto& last_used_stream : last_used_streams) { - const auto& ForEachEagerBlobObject = [&](const auto& DoEachEagerBlobObject) -> Maybe { - for (const auto& eager_blob_object : eager_blob_objects) { - const auto& opt_stream = eager_blob_object->last_used_stream(); - if (unlikely(!opt_stream.has_value())) { continue; } - if (JUST(opt_stream) == last_used_stream) { - JUST(DoEachEagerBlobObject(eager_blob_object)); - } + small_vector, kOpArgsReservedSize> dep_objects{}; + for (const auto& eager_blob_object : eager_blob_objects) { + const auto& opt_stream = eager_blob_object->last_used_stream(); + if (unlikely(!opt_stream.has_value())) { continue; } + if (JUST(opt_stream) == last_used_stream) { + dep_objects.emplace_back(JUST(eager_blob_object->compute_local_dep_object())); } - return Maybe::Ok(); - }; - JUST(DoEach(last_used_stream, ForEachEagerBlobObject)); + } + JUST(DoEach(last_used_stream, std::move(dep_objects))); } } return Maybe::Ok(); @@ -488,12 +486,7 @@ Maybe InstructionsBuilder::SoftSyncStream(const vm::EagerBlobObjectList& e Symbol stream) { JUST(ForEachEagerBlobObjectsNeedingSoftSync( eager_blob_objects, stream, - [&](Symbol last_used_stream, const auto& ForEachEagerBlobObject) -> Maybe { - small_vector, kOpArgsReservedSize> dep_objects{}; - JUST(ForEachEagerBlobObject([&](const auto& eager_blob_object) -> Maybe { - dep_objects.emplace_back(JUST(eager_blob_object->compute_local_dep_object())); - return Maybe::Ok(); - })); + [&](Symbol last_used_stream, auto&& dep_objects) -> Maybe { return SoftSyncStream(std::move(dep_objects), "mut", last_used_stream); })); for (const auto& eager_blob_object : eager_blob_objects) { From 
02f7bf3911ed405fba2e2928b8e1ba72ebf71048 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Mon, 18 Jul 2022 13:21:38 +0800 Subject: [PATCH 23/67] move SmallVector from common/container_util.h to framework/instructions_builder.cpp --- oneflow/core/common/container_util.h | 12 ------------ oneflow/core/framework/instructions_builder.cpp | 12 ++++++++++++ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/oneflow/core/common/container_util.h b/oneflow/core/common/container_util.h index 2f8013b0e4f..9a837094726 100644 --- a/oneflow/core/common/container_util.h +++ b/oneflow/core/common/container_util.h @@ -82,18 +82,6 @@ std::string Join(const T& con, const std::string& delimiter) { return os.str(); } -template -using SmallSet = std::vector; - -template -std::pair::iterator, bool> SmallSetInsert(SmallSet* vec, const T& elem) { - for (auto iter = vec->begin(); iter != vec->end(); ++iter) { - if (*iter == elem) { return std::make_pair(iter, false); } - } - vec->push_back(elem); - return std::make_pair(--vec->end(), true); -} - } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_ diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 8115846bb06..b6e4ed30dfe 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -441,6 +441,18 @@ Maybe InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& namespace { +template +using SmallSet = small_vector; + +template +std::pair::iterator, bool> SmallSetInsert(SmallSet* vec, const T& elem) { + for (auto iter = vec->begin(); iter != vec->end(); ++iter) { + if (*iter == elem) { return std::make_pair(iter, false); } + } + vec->push_back(elem); + return std::make_pair(vec->end() - 1, true); +} + template Maybe ForEachEagerBlobObjectsNeedingSoftSync( const vm::EagerBlobObjectList& eager_blob_objects, Symbol stream, From a88fd428b18ece32736db635d7110bcd89adfe00 Mon Sep 17 00:00:00 
2001 From: clackhan Date: Mon, 18 Jul 2022 15:31:29 +0800 Subject: [PATCH 24/67] mone ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE to eager.h --- oneflow/core/common/env_var/eager.h | 28 +++++++++++++++++++ .../framework/local_tensor_infer_cache.cpp | 11 ++------ 2 files changed, 31 insertions(+), 8 deletions(-) create mode 100644 oneflow/core/common/env_var/eager.h diff --git a/oneflow/core/common/env_var/eager.h b/oneflow/core/common/env_var/eager.h new file mode 100644 index 00000000000..b7df99d310f --- /dev/null +++ b/oneflow/core/common/env_var/eager.h @@ -0,0 +1,28 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ +#define ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ + +#include "oneflow/core/common/env_var/env_var.h" + +namespace oneflow { + +// NOTE: use env variable 'ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE' indicate whether the +// use infer cache in naive local op interpret. +DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE, true); + +} // namespace oneflow +#endif // ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ \ No newline at end of file diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index eabbd06ee36..5d090c2edc4 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -19,7 +19,7 @@ limitations under the License. 
#include "oneflow/core/operator/operator.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/common/container_util.h" -#include "oneflow/core/common/util.h" +#include "oneflow/core/common/env_var/eager.h" #include "oneflow/core/framework/infer_util.h" namespace oneflow { @@ -27,12 +27,6 @@ namespace one { namespace { -// NOTE: use env variable 'ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE' indicate whether the -// use infer cache in naive local op interpret. -bool ParseEnableEagerLocalInferCache() { - return ParseBooleanFromEnv("ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE", true); -} - Maybe CheckIsDeviceSupportedByOp(const Device& device, const std::string& op_type_name) { if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), "cpu"); } return Maybe::Ok(); @@ -202,7 +196,8 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl Maybe LocalTensorInferCache::GetOrInfer( const LocalTensorMetaInferArgs& infer_args) { - static bool enable_eager_local_infer_cache = ParseEnableEagerLocalInferCache(); + static bool enable_eager_local_infer_cache = + ThreadLocalEnvBool(); if (enable_eager_local_infer_cache) { auto iter = cache_.find(infer_args); if (iter == cache_.end()) { From efd48c66f5b725c8d71a798b0aee0ea0bad5ccd0 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 15:38:54 +0800 Subject: [PATCH 25/67] add blank line --- oneflow/core/common/env_var/eager.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/common/env_var/eager.h b/oneflow/core/common/env_var/eager.h index b7df99d310f..ad7108ceb2d 100644 --- a/oneflow/core/common/env_var/eager.h +++ b/oneflow/core/common/env_var/eager.h @@ -25,4 +25,4 @@ namespace oneflow { DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE, true); } // namespace oneflow -#endif // ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ \ No newline at end of file +#endif // ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_ From 4ed595c1ebb86872ee9280db8759cc6e4e2669f8 Mon 
Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 16:56:56 +0800 Subject: [PATCH 26/67] reslove comments --- .../framework/local_tensor_infer_cache.cpp | 67 +++++++++---------- 1 file changed, 30 insertions(+), 37 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 5d090c2edc4..384a1ec00f7 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -51,20 +51,12 @@ Maybe CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { public: UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, - const LocalTensorMetaInferArgs* infer_args, - std::vector* output_tensor_metas) + const LocalTensorMetaInferArgs& infer_args, + std::vector& output_tensor_metas) : user_op_expr_(user_op_expr), - composed_attrs_(infer_args->attrs(), user_op_expr->base_attrs()), - in_tensor_devices_(user_op_expr_->input_size()), - out_tensor_devices_(user_op_expr_->output_size()) { - for (int i = 0; i < user_op_expr_->input_size(); ++i) { - const auto& device = infer_args->input_local_tensor_metas().at(i)->device(); - in_tensor_devices_.at(i) = device; - } - for (int i = 0; i < user_op_expr_->output_size(); ++i) { - out_tensor_devices_.at(i) = output_tensor_metas->at(i).mut_device(); - } - } + composed_attrs_(infer_args.attrs(), user_op_expr->base_attrs()), + infer_args_(infer_args), + output_tensor_metas_(output_tensor_metas) {} const std::vector>& inputs() const override { return user_op_expr_->indexed_input_pairs(); @@ -80,7 +72,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); CHECK_LT(tuple_index, user_op_expr_->output_size()); - return out_tensor_devices_.at(tuple_index); 
+ return output_tensor_metas_.at(tuple_index).mut_device(); } Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, @@ -89,7 +81,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); CHECK_LT(tuple_index, user_op_expr_->input_size()); - return in_tensor_devices_.at(tuple_index); + return infer_args_.input_local_tensor_metas().at(tuple_index)->device(); } private: @@ -99,21 +91,32 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr } const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; - small_vector, kOpArgsReservedSize> in_tensor_devices_; - small_vector*, kOpArgsReservedSize> out_tensor_devices_; + const LocalTensorMetaInferArgs& infer_args_; + std::vector& output_tensor_metas_; }; Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, + const Symbol& default_device, const LocalTensorMetaInferArgs& infer_args, - std::vector* output_tensor_metas) { - if (!user_op_expr.device_and_stream_infer_fn()) { - Symbol device = infer_args.input_local_tensor_metas().at(0)->device(); - return GetDefaultStreamByDevice(device); + std::vector& output_tensor_metas) { + Symbol stream; + if (!user_op_expr.has_device_and_stream_infer_fn()) { + stream = JUST(GetDefaultStreamByDevice(default_device)); + for (int i = 0; i < user_op_expr.output_size(); i++) { + auto& tensor_meta = output_tensor_metas.at(i); + *tensor_meta.mut_device() = default_device; + } } else { - UserOpExprDeviceAndStreamInferContext device_and_stream_ctx(&user_op_expr, &infer_args, - output_tensor_metas); - return TRY(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx)); + if (!user_op_expr.device_and_stream_infer_fn()) { + Symbol device = infer_args.input_local_tensor_metas().at(0)->device(); + stream = JUST(GetDefaultStreamByDevice(device)); + } else { + UserOpExprDeviceAndStreamInferContext 
device_and_stream_ctx(&user_op_expr, infer_args, + output_tensor_metas); + stream = JUST(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx)); + } } + return stream; } } // namespace @@ -162,16 +165,8 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl std::vector output_mut_metas(user_op_expr.output_size()); // Infer devices - Symbol stream; - if (!user_op_expr.has_device_and_stream_infer_fn()) { - stream = JUST(GetDefaultStreamByDevice(default_device)); - for (int i = 0; i < user_op_expr.output_size(); i++) { - auto& tensor_meta = output_mut_metas.at(i); - *tensor_meta.mut_device() = default_device; - } - } else { - stream = JUST(InferDeviceAndStream(user_op_expr, infer_args, &output_mut_metas)); - } + Symbol stream = + JUST(InferDeviceAndStream(user_op_expr, default_device, infer_args, output_mut_metas)); result->set_stream(stream); { @@ -196,9 +191,7 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl Maybe LocalTensorInferCache::GetOrInfer( const LocalTensorMetaInferArgs& infer_args) { - static bool enable_eager_local_infer_cache = - ThreadLocalEnvBool(); - if (enable_eager_local_infer_cache) { + if (ThreadLocalEnvBool()) { auto iter = cache_.find(infer_args); if (iter == cache_.end()) { const auto& user_op_expr = user_op_expr_.lock(); From 82bbec50ca524a336e1ea76e396407e17ab4d786 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 16:59:46 +0800 Subject: [PATCH 27/67] minor fix --- oneflow/core/framework/local_tensor_infer_cache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 384a1ec00f7..15626fcd154 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -178,13 +178,13 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl [&](int32_t i) -> TensorMeta* { 
return &output_mut_metas.at(i); })); } - auto* output_metas = result->mut_output_tensor_metas(); + auto* mut_output_tensor_metas = result->mut_output_tensor_metas(); for (int32_t i = 0; i < user_op_expr.output_size(); ++i) { if (!JUST(user_op_expr.SupportNonContiguous())) { std::shared_ptr stride(new Stride(output_mut_metas.at(i).shape())); output_mut_metas.at(i).set_stride(stride); } - output_metas->at(i) = SymbolOf(output_mut_metas.at(i)); + mut_output_tensor_metas->at(i) = SymbolOf(output_mut_metas.at(i)); } return std::shared_ptr(std::move(result)); } From f747252c760ec006b53964c11fc83b4c64bb17dd Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 18 Jul 2022 18:01:01 +0800 Subject: [PATCH 28/67] refine --- .../framework/local_tensor_infer_cache.cpp | 24 +++++++++---------- oneflow/core/framework/tensor_meta.h | 16 +++++++++++++ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 15626fcd154..77e4c127b70 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -50,9 +50,9 @@ Maybe CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { public: - UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, - const LocalTensorMetaInferArgs& infer_args, - std::vector& output_tensor_metas) + UserOpExprDeviceAndStreamInferContext( + const UserOpExpr* user_op_expr, const LocalTensorMetaInferArgs& infer_args, + small_vector* output_tensor_metas) : user_op_expr_(user_op_expr), composed_attrs_(infer_args.attrs(), user_op_expr->base_attrs()), infer_args_(infer_args), @@ -72,7 +72,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 
0); CHECK_LT(tuple_index, user_op_expr_->output_size()); - return output_tensor_metas_.at(tuple_index).mut_device(); + return output_tensor_metas_->at(tuple_index).mut_device(); } Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, @@ -92,18 +92,18 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; const LocalTensorMetaInferArgs& infer_args_; - std::vector& output_tensor_metas_; + small_vector* output_tensor_metas_; }; -Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, - const Symbol& default_device, - const LocalTensorMetaInferArgs& infer_args, - std::vector& output_tensor_metas) { +Maybe> InferDeviceAndStream( + const UserOpExpr& user_op_expr, const Symbol& default_device, + const LocalTensorMetaInferArgs& infer_args, + small_vector* output_tensor_metas) { Symbol stream; if (!user_op_expr.has_device_and_stream_infer_fn()) { stream = JUST(GetDefaultStreamByDevice(default_device)); for (int i = 0; i < user_op_expr.output_size(); i++) { - auto& tensor_meta = output_tensor_metas.at(i); + auto& tensor_meta = output_tensor_metas->at(i); *tensor_meta.mut_device() = default_device; } } else { @@ -163,10 +163,10 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl auto result = std::make_unique(user_op_expr.output_size()); - std::vector output_mut_metas(user_op_expr.output_size()); + small_vector output_mut_metas(user_op_expr.output_size()); // Infer devices Symbol stream = - JUST(InferDeviceAndStream(user_op_expr, default_device, infer_args, output_mut_metas)); + JUST(InferDeviceAndStream(user_op_expr, default_device, infer_args, &output_mut_metas)); result->set_stream(stream); { diff --git a/oneflow/core/framework/tensor_meta.h b/oneflow/core/framework/tensor_meta.h index df8bd82f815..32a0789c666 100644 --- a/oneflow/core/framework/tensor_meta.h +++ b/oneflow/core/framework/tensor_meta.h @@ -70,6 +70,15 @@ class 
TensorMeta : public user_op::TensorDesc { bool* mut_is_dynamic() override { return &is_dynamic_; } void set_is_dynamic(bool val) override { is_dynamic_ = val; } + protected: + TensorMeta& operator=(const TensorMeta& other) { + this->shape_ = std::make_shared(*other.shape_); + this->stride_ = std::make_shared(*other.stride_); + this->data_type_ = other.data_type_; + this->is_dynamic_ = other.is_dynamic_; + return *this; + } + private: std::shared_ptr shape_; std::shared_ptr stride_; @@ -97,6 +106,13 @@ class LocalTensorMeta : public TensorMeta { bool operator==(const LocalTensorMeta& other) const; size_t CalcHashValue() const; + LocalTensorMeta& operator=(const LocalTensorMeta& other) { + TensorMeta::operator=(other); + this->device_ = other.device_; + this->storage_offset_ = other.storage_offset_; + return *this; + } + private: Symbol device_; int64_t storage_offset_; From 88367cb6b0f0a92699a42c313bd046f01ea03c90 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Mon, 18 Jul 2022 22:55:50 +0800 Subject: [PATCH 29/67] explicit scalar initialization --- oneflow/core/operator/operator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 70ba577f1c9..a7f7eba9de0 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -1322,7 +1322,7 @@ Maybe Operator::ToOpAttribute(OpAttribute* op_attribute) const { } else { ParallelConf parallel_conf = pair.second->parallel_conf(); const auto MakeParallelDescSymbol = [¶llel_conf]() -> Maybe { - int64_t symbol_id; + int64_t symbol_id = 0; const auto BuildInstruction = [&symbol_id, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { symbol_id = JUST(JUST(builder->GetParallelDescSymbol(parallel_conf))->symbol_id()); From f5fcbf072c5c8c8459bd14614d56ae3d40d6ec30 Mon Sep 17 00:00:00 2001 From: clackhan Date: Tue, 19 Jul 2022 14:38:10 +0800 Subject: [PATCH 30/67] fix static check error --- 
oneflow/core/framework/local_tensor_infer_cache.cpp | 6 +++--- .../op_interpreter/eager_local_op_interpreter.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 77e4c127b70..0858da8a655 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -28,7 +28,7 @@ namespace one { namespace { Maybe CheckIsDeviceSupportedByOp(const Device& device, const std::string& op_type_name) { - if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), "cpu"); } + if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), "cpu"); } // NOLINT return Maybe::Ok(); } @@ -149,7 +149,7 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl for (int i = 0; i < input_tensors.size(); ++i) { LocalTensorMeta* local_tensor_meta = dynamic_cast(input_tensors.at(i)->mut_tensor_meta()); - CHECK_NOTNULL_OR_RETURN(local_tensor_meta); + CHECK_NOTNULL_OR_RETURN(local_tensor_meta); // NOLINT input_local_tensor_metas_.at(i) = SymbolOf(*local_tensor_meta); } return Maybe::Ok(); @@ -195,7 +195,7 @@ Maybe LocalTensorInferCache::GetOrInfer( auto iter = cache_.find(infer_args); if (iter == cache_.end()) { const auto& user_op_expr = user_op_expr_.lock(); - CHECK_OR_RETURN(static_cast(user_op_expr)); + CHECK_OR_RETURN(static_cast(user_op_expr)); // NOLINT const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args)); iter = cache_.emplace(infer_args, output_tensor_metas).first; } diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 46b09e2cf68..9467b1c1799 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -69,7 +69,7 @@ Maybe TensorImpl4Tensor(const 
std::shared_ptr& te Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) { OF_PROFILER_RANGE_GUARD("NaiveInterpret"); - CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); + CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); // NOLINT Symbol default_device = JUST(GetDefaultDevice(inputs, ctx)); std::shared_ptr result; { @@ -107,10 +107,10 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); // output i is inplaced. // check TensorMeta of infer result and TensorMeta of output i. - CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas.at(i)->shape()); - CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas.at(i)->dtype()); + CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas.at(i)->shape()); // NOLINT + CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas.at(i)->dtype()); // NOLINT bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); - CHECK_OR_RETURN(has_eager_blob_object); + CHECK_OR_RETURN(has_eager_blob_object); // NOLINT output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object()); // TODO(zhaoluyang):(thread_local TensorMeta set stride then check) // CHECK_OR_RETURN(tensor_impl->tensor_meta()->stride() == From faeb57e44e9e9bbae3a14faf5e4dba55e9b7af25 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Tue, 19 Jul 2022 06:50:20 +0000 Subject: [PATCH 31/67] auto format by CI --- .../framework/op_interpreter/eager_local_op_interpreter.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 9467b1c1799..b6c6eebc95e 100644 --- 
a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -107,8 +107,10 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); // output i is inplaced. // check TensorMeta of infer result and TensorMeta of output i. - CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas.at(i)->shape()); // NOLINT - CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas.at(i)->dtype()); // NOLINT + CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() + == output_tensor_metas.at(i)->shape()); // NOLINT + CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() + == output_tensor_metas.at(i)->dtype()); // NOLINT bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); CHECK_OR_RETURN(has_eager_blob_object); // NOLINT output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object()); From 358a018624db6ef00a5c5a1069e7f65be409cb56 Mon Sep 17 00:00:00 2001 From: clackhan Date: Tue, 19 Jul 2022 15:35:30 +0800 Subject: [PATCH 32/67] of_format --- .../framework/op_interpreter/eager_local_op_interpreter.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 9467b1c1799..2440adb0b0c 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -107,8 +107,10 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); // output i is inplaced. // check TensorMeta of infer result and TensorMeta of output i. 
- CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas.at(i)->shape()); // NOLINT - CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas.at(i)->dtype()); // NOLINT + CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() // NOLINT + == output_tensor_metas.at(i)->shape()); // NOLINT + CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() // NOLINT + == output_tensor_metas.at(i)->dtype()); // NOLINT bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object()); CHECK_OR_RETURN(has_eager_blob_object); // NOLINT output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object()); From 403620a3c4e2e4462445c5691c79bee626a9552f Mon Sep 17 00:00:00 2001 From: clackhan Date: Tue, 19 Jul 2022 15:48:24 +0800 Subject: [PATCH 33/67] reslove comment --- oneflow/core/framework/local_tensor_infer_cache.cpp | 4 ++-- .../framework/op_interpreter/eager_local_op_interpreter.cpp | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 0858da8a655..2bbc1aa03eb 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -124,9 +124,9 @@ Maybe> InferDeviceAndStream( size_t LocalTensorMetaInferArgs::hash_value() const { size_t hash_value = std::hash()(attrs_); HashCombine(&hash_value, std::hash>()(default_device_)); - const auto& tensor_meta_hash_functor = std::hash(); + const auto& tensor_meta_hash_functor = std::hash>(); for (const auto& tensor_meta : input_local_tensor_metas_) { - HashCombine(&hash_value, tensor_meta_hash_functor(*tensor_meta)); + HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta)); } return hash_value; } diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 2440adb0b0c..8dc7e9ead2b 100644 --- 
a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -48,12 +48,16 @@ namespace one { namespace { +Maybe> RawGetDefaultCpuDevice() { return Device::New("cpu", 0); } + +constexpr auto* GetDefaultCpuDevice = DECORATE(&RawGetDefaultCpuDevice, ThreadLocal); + Maybe> GetDefaultDevice(const TensorTuple& inputs, const OpExprInterpContext& ctx) { if (inputs.empty()) { if (ctx.device.has_value()) { return JUST(ctx.device); } else { - return Device::New("cpu", 0); + return GetDefaultCpuDevice(); } } return JUST(inputs.at(0)->device()); From ebddd178866a10f3dc1d93fdf7e2f22cd815d9f5 Mon Sep 17 00:00:00 2001 From: clackhan Date: Tue, 19 Jul 2022 17:08:00 +0800 Subject: [PATCH 34/67] refine --- .../framework/local_tensor_infer_cache.cpp | 19 +++++++++---------- .../core/framework/local_tensor_infer_cache.h | 16 ++++++++-------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp index 2bbc1aa03eb..e4c246d5837 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.cpp +++ b/oneflow/core/framework/local_tensor_infer_cache.cpp @@ -34,7 +34,6 @@ Maybe CheckIsDeviceSupportedByOp(const Device& device, const std::string& Maybe CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args, Symbol default_device) { - if (infer_args.input_local_tensor_metas().empty()) { return Maybe::Ok(); } for (int i = 0; i < infer_args.input_local_tensor_metas().size(); ++i) { CHECK_OR_RETURN(default_device == JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device()) @@ -50,9 +49,9 @@ Maybe CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { public: - UserOpExprDeviceAndStreamInferContext( - const UserOpExpr* user_op_expr, const LocalTensorMetaInferArgs& 
infer_args, - small_vector* output_tensor_metas) + UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, + const LocalTensorMetaInferArgs& infer_args, + OpArgsVector* output_tensor_metas) : user_op_expr_(user_op_expr), composed_attrs_(infer_args.attrs(), user_op_expr->base_attrs()), infer_args_(infer_args), @@ -92,13 +91,13 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr const UserOpExpr* user_op_expr_; const ComposedAttrMap composed_attrs_; const LocalTensorMetaInferArgs& infer_args_; - small_vector* output_tensor_metas_; + OpArgsVector* output_tensor_metas_; }; -Maybe> InferDeviceAndStream( - const UserOpExpr& user_op_expr, const Symbol& default_device, - const LocalTensorMetaInferArgs& infer_args, - small_vector* output_tensor_metas) { +Maybe> InferDeviceAndStream(const UserOpExpr& user_op_expr, + const Symbol& default_device, + const LocalTensorMetaInferArgs& infer_args, + OpArgsVector* output_tensor_metas) { Symbol stream; if (!user_op_expr.has_device_and_stream_infer_fn()) { stream = JUST(GetDefaultStreamByDevice(default_device)); @@ -163,7 +162,7 @@ Maybe LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl auto result = std::make_unique(user_op_expr.output_size()); - small_vector output_mut_metas(user_op_expr.output_size()); + OpArgsVector output_mut_metas(user_op_expr.output_size()); // Infer devices Symbol stream = JUST(InferDeviceAndStream(user_op_expr, default_device, infer_args, &output_mut_metas)); diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h index 58d745bd987..534278a2da5 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.h +++ b/oneflow/core/framework/local_tensor_infer_cache.h @@ -31,6 +31,9 @@ class Device; namespace one { +template +using OpArgsVector = small_vector; + class TensorTuple; class UserOpExpr; @@ -41,8 +44,7 @@ class LocalTensorMetaInferArgs final { 
LocalTensorMetaInferArgs(LocalTensorMetaInferArgs&&) = default; ~LocalTensorMetaInferArgs() = default; - const small_vector, kOpArgsReservedSize>& input_local_tensor_metas() - const { + const OpArgsVector>& input_local_tensor_metas() const { return input_local_tensor_metas_; } const AttrMap& attrs() const { return attrs_; } @@ -61,7 +63,7 @@ class LocalTensorMetaInferArgs final { AttrMap attrs_; Symbol default_device_; - small_vector, kOpArgsReservedSize> input_local_tensor_metas_; + OpArgsVector> input_local_tensor_metas_; }; } // namespace one @@ -88,18 +90,16 @@ class LocalTensorInferResult final { LocalTensorInferResult(LocalTensorInferResult&&) = delete; ~LocalTensorInferResult() = default; - const small_vector, kOpArgsReservedSize>& output_tensor_metas() const { + const OpArgsVector>& output_tensor_metas() const { return output_tensor_metas_; } - small_vector, kOpArgsReservedSize>* mut_output_tensor_metas() { - return &output_tensor_metas_; - } + OpArgsVector>* mut_output_tensor_metas() { return &output_tensor_metas_; } const Symbol& stream() const { return stream_; } void set_stream(const Symbol& stream) { stream_ = stream; } private: - small_vector, kOpArgsReservedSize> output_tensor_metas_; + OpArgsVector> output_tensor_metas_; Symbol stream_; }; From 783ecb5e559763daa871f27c8eddbbdfd7cd1ff5 Mon Sep 17 00:00:00 2001 From: clackhan Date: Tue, 19 Jul 2022 17:22:19 +0800 Subject: [PATCH 35/67] refine --- .../op_interpreter/eager_local_op_interpreter.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp index 8dc7e9ead2b..635e9889ea9 100644 --- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp @@ -75,12 +75,12 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in 
OF_PROFILER_RANGE_GUARD("NaiveInterpret"); CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); // NOLINT Symbol default_device = JUST(GetDefaultDevice(inputs, ctx)); - std::shared_ptr result; - { - LocalTensorMetaInferArgs infer_args; - JUST(infer_args.Init(ctx.attrs, default_device, inputs)); - result = JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); - } + const std::shared_ptr result = + JUST([&]() -> Maybe { + LocalTensorMetaInferArgs infer_args; + JUST(infer_args.Init(ctx.attrs, default_device, inputs)); + return JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); + }()); vm::EagerBlobObjectList input_eager_blob_objects(inputs.size()); for (int i = 0; i < inputs.size(); i++) { From a2b874b0af8d1542853d50d5be45e1eb30a969d3 Mon Sep 17 00:00:00 2001 From: clackhan Date: Tue, 19 Jul 2022 18:17:54 +0800 Subject: [PATCH 36/67] refine --- oneflow/core/framework/tensor_meta.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/oneflow/core/framework/tensor_meta.h b/oneflow/core/framework/tensor_meta.h index 32a0789c666..1316706bba9 100644 --- a/oneflow/core/framework/tensor_meta.h +++ b/oneflow/core/framework/tensor_meta.h @@ -106,12 +106,7 @@ class LocalTensorMeta : public TensorMeta { bool operator==(const LocalTensorMeta& other) const; size_t CalcHashValue() const; - LocalTensorMeta& operator=(const LocalTensorMeta& other) { - TensorMeta::operator=(other); - this->device_ = other.device_; - this->storage_offset_ = other.storage_offset_; - return *this; - } + LocalTensorMeta& operator=(const LocalTensorMeta& other) = default; private: Symbol device_; From a3c6f57af2a50f6b59feef1e5d51574d1ba33ab5 Mon Sep 17 00:00:00 2001 From: clackhan Date: Wed, 20 Jul 2022 11:14:13 +0800 Subject: [PATCH 37/67] fix error --- oneflow/core/framework/tensor_impl.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/oneflow/core/framework/tensor_impl.cpp 
b/oneflow/core/framework/tensor_impl.cpp index 9999604dfeb..bf882e53c26 100644 --- a/oneflow/core/framework/tensor_impl.cpp +++ b/oneflow/core/framework/tensor_impl.cpp @@ -132,8 +132,7 @@ Maybe EagerLocalTensorImpl::is_pinned() const { Maybe EagerLocalTensorImpl::set_eager_blob_object( std::shared_ptr eager_blob_object) { eager_blob_object_ = eager_blob_object; - CHECK_OR_RETURN(eager_blob_object_->shape_ptr().get() == tensor_meta()->shape_ptr().get()) - << kOfBugIssueUploadPrompt; + CHECK_OR_RETURN(eager_blob_object_->shape() == tensor_meta()->shape()) << kOfBugIssueUploadPrompt; CHECK_OR_RETURN(eager_blob_object_->data_type() == tensor_meta()->dtype()) << kOfBugIssueUploadPrompt; JUST(UpdateTensorStorage()); From d25f71d0789ec798ab519668ab5832e24b232e05 Mon Sep 17 00:00:00 2001 From: clackhan Date: Wed, 20 Jul 2022 16:33:51 +0800 Subject: [PATCH 38/67] define MutOutputShape and MutOutputStride in InferContext --- oneflow/core/framework/infer_util.cpp | 2 +- oneflow/core/framework/infer_util.h | 12 ++-- oneflow/core/framework/op_expr.cpp | 34 +++++++++-- oneflow/core/framework/op_kernel.cpp | 2 +- oneflow/core/kernel/user_kernel.cpp | 28 +++++++-- oneflow/core/operator/user_op.cpp | 38 +++++++++--- oneflow/ir/oneflow-extension/extension.cpp | 2 +- ...ttention_query_mul_key_and_value_kernel.cu | 4 +- ...random_batch_permutation_indices_kernel.cu | 4 +- .../kernels/nccl_logical_send_recv_kernel.cpp | 4 +- oneflow/user/kernels/nms_kernel.cu | 4 +- oneflow/user/kernels/stateful_opkernel.cpp | 58 ++++++++++++++----- .../user/kernels/two_stage_reduce_kernel.cpp | 8 +-- .../kernels/unsorted_segment_sum_kernel.cpp | 4 +- oneflow/user/kernels/where_kernel.cpp | 20 +++---- oneflow/user/ops/acc_op.cpp | 2 +- oneflow/user/ops/adaptive_pool_op.cpp | 4 +- oneflow/user/ops/arange_op.cpp | 4 +- oneflow/user/ops/arg_sort_op.cpp | 2 +- oneflow/user/ops/argmax_op.cpp | 2 +- oneflow/user/ops/avg_pool_op.cpp | 4 +- oneflow/user/ops/bias_add_op.cpp | 2 +- 
oneflow/user/ops/broadcast_div_grad_op.cpp | 2 +- oneflow/user/ops/broadcast_like_op.cpp | 4 +- oneflow/user/ops/broadcast_pow_grad_op.cpp | 4 +- oneflow/user/ops/buffer_op.cpp | 2 +- oneflow/user/ops/cast_like_op.cpp | 2 +- oneflow/user/ops/cast_to_tick_op.cpp | 2 +- .../ops/categorical_ordinal_encode_op.cpp | 4 +- oneflow/user/ops/celu_op.cpp | 4 +- oneflow/user/ops/clip_by_value_op.cpp | 4 +- oneflow/user/ops/combined_margin_loss_op.cpp | 4 +- oneflow/user/ops/constant_op.cpp | 4 +- oneflow/user/ops/conv_op.cpp | 2 +- oneflow/user/ops/copy_op.cpp | 4 +- oneflow/user/ops/ctc_loss_op.cpp | 10 ++-- .../cublas_bias_add_relu_matmul_grad_op.cpp | 4 +- .../cublas_fused_matmul_bias_add_grad_op.cpp | 4 +- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 6 +- oneflow/user/ops/cublas_fused_mlp_op.cpp | 6 +- oneflow/user/ops/cum_ops.cpp | 6 +- oneflow/user/ops/data_shuffle_op.cpp | 24 ++++---- oneflow/user/ops/distributions/normal_op.cpp | 4 +- .../user/ops/distributions/uniform_int_op.cpp | 4 +- oneflow/user/ops/distributions/uniform_op.cpp | 4 +- oneflow/user/ops/dot_op.cpp | 2 +- oneflow/user/ops/dropout_op.cpp | 8 +-- oneflow/user/ops/eager_b_to_s_op.cpp | 2 +- oneflow/user/ops/eager_nccl_ops.cpp | 14 ++--- oneflow/user/ops/eager_p_to_b_op.cpp | 2 +- oneflow/user/ops/eager_p_to_s_op.cpp | 2 +- oneflow/user/ops/eager_s_to_b_op.cpp | 2 +- oneflow/user/ops/eager_s_to_p_op.cpp | 2 +- oneflow/user/ops/eager_s_to_s_op.cpp | 2 +- .../user/ops/eager_symmetric_s_to_p_op.cpp | 2 +- oneflow/user/ops/elu_op.cpp | 4 +- oneflow/user/ops/embedding_op.cpp | 2 +- oneflow/user/ops/empty_op.cpp | 8 +-- oneflow/user/ops/erfinv_op.cpp | 2 +- oneflow/user/ops/expand_dims_op.cpp | 2 +- oneflow/user/ops/expand_op.cpp | 4 +- oneflow/user/ops/eye_op.cpp | 2 +- oneflow/user/ops/fake_quantization_op.cpp | 2 +- oneflow/user/ops/fill_op.cpp | 8 +-- oneflow/user/ops/fused_bias_add_op.cpp | 6 +- .../fused_cross_feature_interaction_op.cpp | 20 +++---- .../ops/fused_dot_feature_interaction_op.cpp | 10 
++-- oneflow/user/ops/fused_gru_cell_op.cpp | 14 ++--- oneflow/user/ops/fused_lstm_cell_op.cpp | 14 +++-- .../fused_matmul_bias_add_relu_dropout_op.cpp | 6 +- .../user/ops/fused_relu_dropout_grad_op.cpp | 2 +- .../fused_scale_mask_softmax_dropout_op.cpp | 4 +- .../user/ops/fused_scale_mask_softmax_op.cpp | 2 +- ...fused_scale_tril_softmax_mask_scale_op.cpp | 4 +- ..._attention_query_mul_key_and_value_ops.cpp | 6 +- oneflow/user/ops/gelu_op.cpp | 4 +- ...te_random_batch_permutation_indices_op.cpp | 2 +- oneflow/user/ops/hardshrink_op.cpp | 4 +- oneflow/user/ops/hardsigmoid_op.cpp | 4 +- oneflow/user/ops/hardswish_op.cpp | 4 +- oneflow/user/ops/hardtanh_op.cpp | 4 +- .../ops/hierarchical_parallel_cast_op.cpp | 4 +- oneflow/user/ops/identity_op.cpp | 2 +- .../user/ops/image_object_preprocess_ops.cpp | 14 ++--- oneflow/user/ops/image_preprocess_ops.cpp | 2 +- .../user/ops/l1_l2_regularize_gradient_op.cpp | 2 +- oneflow/user/ops/l2_normalize_op.cpp | 6 +- oneflow/user/ops/leaky_relu_op.cpp | 4 +- oneflow/user/ops/log_softmax_op.cpp | 4 +- oneflow/user/ops/masked_fill_op.cpp | 2 +- .../user/ops/math_binary_broadcast_ops.cpp | 10 ++-- oneflow/user/ops/matmul_op.cpp | 2 +- oneflow/user/ops/matrix_vector_product_op.cpp | 6 +- oneflow/user/ops/median_op.cpp | 2 +- oneflow/user/ops/median_with_indices_op.cpp | 4 +- oneflow/user/ops/min_max_observer_op.cpp | 12 ++-- oneflow/user/ops/mish_op.cpp | 4 +- oneflow/user/ops/model_update_ops.cpp | 2 +- .../moving_average_min_max_observer_op.cpp | 4 +- oneflow/user/ops/multi_reduce_ops.cpp | 6 +- oneflow/user/ops/narrow_op.cpp | 2 +- oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp | 10 ++-- oneflow/user/ops/nccl_logical_ops.cpp | 14 ++--- oneflow/user/ops/nd_index_slice_ops.cpp | 8 +-- oneflow/user/ops/nms_op.cpp | 2 +- oneflow/user/ops/nvtx_range_op.cpp | 4 +- oneflow/user/ops/one_embedding_ops.cpp | 22 +++---- oneflow/user/ops/ones_like_op.cpp | 4 +- oneflow/user/ops/p2p_comm_op.cpp | 2 +- oneflow/user/ops/pad_op.cpp | 2 +- 
oneflow/user/ops/padding_ops.cpp | 8 +-- oneflow/user/ops/parallel_cast_op.cpp | 2 +- oneflow/user/ops/partial_fc_sample_op.cpp | 4 +- oneflow/user/ops/prelu_op.cpp | 6 +- oneflow/user/ops/quantization_op.cpp | 2 +- oneflow/user/ops/randperm_op.cpp | 4 +- oneflow/user/ops/reduce_ops.cpp | 4 +- oneflow/user/ops/relu_op.cpp | 4 +- oneflow/user/ops/repeat_op.cpp | 2 +- oneflow/user/ops/reshape_like_op.cpp | 2 +- oneflow/user/ops/roi_align_op.cpp | 4 +- oneflow/user/ops/roll_op.cpp | 2 +- oneflow/user/ops/same_padding_op.cpp | 2 +- oneflow/user/ops/scalar_logical_op.cpp | 2 +- oneflow/user/ops/scalar_math_op.cpp | 6 +- oneflow/user/ops/search_sorted_op.cpp | 4 +- oneflow/user/ops/selu_op.cpp | 4 +- oneflow/user/ops/silu_op.cpp | 4 +- oneflow/user/ops/slice_op.cpp | 6 +- oneflow/user/ops/softmax_cross_entropy_op.cpp | 4 +- oneflow/user/ops/softmax_op.cpp | 4 +- oneflow/user/ops/softplus_op.cpp | 4 +- oneflow/user/ops/softshrink_op.cpp | 4 +- oneflow/user/ops/softsign_op.cpp | 4 +- oneflow/user/ops/sort_op.cpp | 2 +- oneflow/user/ops/sparse_cross_entropy_op.cpp | 2 +- .../ops/sparse_softmax_cross_entropy_op.cpp | 4 +- oneflow/user/ops/squeeze_op.cpp | 2 +- oneflow/user/ops/ssp_variable_proxy_op.cpp | 4 +- oneflow/user/ops/tf_pool_op.cpp | 2 +- oneflow/user/ops/tf_prelu_op.cpp | 2 +- oneflow/user/ops/threshold_op.cpp | 4 +- oneflow/user/ops/to_contiguous_op.cpp | 4 +- oneflow/user/ops/top_k_op.cpp | 2 +- oneflow/user/ops/tuple_identity_op.cpp | 2 +- oneflow/user/ops/two_stage_reduce_ops.cpp | 20 +++---- oneflow/user/ops/unfold_fold_op.cpp | 4 +- oneflow/user/ops/unfold_tensor_op.cpp | 2 +- oneflow/user/ops/unsorted_segment_sum_op.cpp | 4 +- oneflow/user/ops/upsample_op.cpp | 14 ++--- oneflow/user/ops/util_ops.cpp | 4 +- oneflow/user/ops/variance_op.cpp | 2 +- oneflow/user/ops/vector_matrix_product_op.cpp | 6 +- oneflow/user/ops/where_op.cpp | 14 ++--- oneflow/user/ops/zero_like_op.cpp | 2 +- 155 files changed, 497 insertions(+), 399 deletions(-) diff --git 
a/oneflow/core/framework/infer_util.cpp b/oneflow/core/framework/infer_util.cpp index 599f6a9070d..4ccd9ca7955 100644 --- a/oneflow/core/framework/infer_util.cpp +++ b/oneflow/core/framework/infer_util.cpp @@ -40,7 +40,7 @@ Maybe TensorDescInferFnUtil::Unchanged(InferContext* ctx) { for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); *ctx->OutputIsDynamic(output_arg.first, output_arg.second) = first_tensor_desc->is_dynamic(); - *ctx->OutputShape(output_arg.first, output_arg.second) = first_tensor_desc->shape(); + *ctx->MutOutputShape(output_arg.first, output_arg.second) = first_tensor_desc->shape(); } return Maybe::Ok(); } diff --git a/oneflow/core/framework/infer_util.h b/oneflow/core/framework/infer_util.h index 1fcb07a7590..960d9bba9bf 100644 --- a/oneflow/core/framework/infer_util.h +++ b/oneflow/core/framework/infer_util.h @@ -43,11 +43,15 @@ class InferContext { virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const Shape& InputShape(const std::string&, int32_t) const = 0; - virtual Shape* OutputShape(const std::string&, int32_t) = 0; - virtual Shape* Shape4ArgNameAndIndex(const std::string&, int32_t) = 0; + virtual const Shape& OutputShape(const std::string&, int32_t) const = 0; + virtual Shape* MutOutputShape(const std::string&, int32_t) = 0; + virtual const Shape& Shape4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual Shape* MutShape4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const Stride& InputStride(const std::string&, int32_t) const = 0; - virtual Stride* OutputStride(const std::string&, int32_t) = 0; - virtual Stride* Stride4ArgNameAndIndex(const std::string&, int32_t) = 0; + virtual const Stride& OutputStride(const std::string&, int32_t) const = 0; + virtual Stride* MutOutputStride(const std::string&, int32_t) = 0; + virtual const Stride& Stride4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual 
Stride* MutStride4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const DataType& InputDType(const std::string&, int32_t) const = 0; virtual DataType* OutputDType(const std::string&, int32_t) = 0; virtual DataType* Dtype4ArgNameAndIndex(const std::string&, int32_t) = 0; diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 13113237061..9e07d3f0ccc 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -221,14 +221,27 @@ class UserOpExprInferContext : public user_op::InferContext { return tensor_meta4input_index_(tuple_index)->shape(); } - Shape* OutputShape(const std::string& name, int32_t index) override { + const Shape& OutputShape(const std::string& name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4input_index_(tuple_index)->shape(); + } + + Shape* MutOutputShape(const std::string& name, int32_t index) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4output_index_(tuple_index)->mut_shape(); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->shape(); + } + + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); } @@ -239,14 +252,27 @@ class UserOpExprInferContext : public user_op::InferContext { return tensor_meta4input_index_(tuple_index)->stride(); } - Stride* OutputStride(const std::string& name, int32_t index) override { + const Stride& 
OutputStride(const std::string& name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4input_index_(tuple_index)->stride(); + } + + Stride* MutOutputStride(const std::string& name, int32_t index) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4output_index_(tuple_index)->mut_stride(); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->stride(); + } + + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } diff --git a/oneflow/core/framework/op_kernel.cpp b/oneflow/core/framework/op_kernel.cpp index 73add18775f..cbbfc59f2d7 100644 --- a/oneflow/core/framework/op_kernel.cpp +++ b/oneflow/core/framework/op_kernel.cpp @@ -25,7 +25,7 @@ void OpKernel::InferShape(KernelInferContext* ctx) const { CHECK_NOTNULL(op_infer_ctx); ctx->GetOpInferFn()(op_infer_ctx); for (const auto& arg_pair : ctx->outputs()) { - const Shape& shape = *op_infer_ctx->OutputShape(arg_pair.first, arg_pair.second); + const Shape& shape = op_infer_ctx->OutputShape(arg_pair.first, arg_pair.second); auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex(arg_pair.first, arg_pair.second); mut_shape_view.set_shape(shape); } diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 12c40c20d2a..0dd9a3c26d2 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -261,21 +261,37 @@ class 
UserKernelOpInferContext : public user_op::InferContext { return it->second.get(); } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Shape4ArgNameAndIndex(arg_name, index); + return Shape4ArgNameAndIndex(arg_name, index); } - Shape* OutputShape(const std::string& arg_name, int32_t index) override { + const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Shape* MutOutputShape(const std::string& arg_name, int32_t index) override { + return MutShape4ArgNameAndIndex(arg_name, index); + } + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->shape(); + } + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Stride4ArgNameAndIndex(arg_name, index); + return Stride4ArgNameAndIndex(arg_name, index); } - Stride* OutputStride(const std::string& arg_name, int32_t index) override { + const Stride& OutputStride(const std::string& arg_name, int32_t index) const override { return Stride4ArgNameAndIndex(arg_name, index); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Stride* MutOutputStride(const std::string& arg_name, int32_t index) override { + return MutStride4ArgNameAndIndex(arg_name, index); + } + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->stride(); + } + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) 
override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } const DataType& InputDType(const std::string& arg_name, int32_t index) const override { diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index c334877fe83..d2cf798c25b 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -171,23 +171,45 @@ class UserOpInferContext final : public user_op::InferContext { } } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Shape4ArgNameAndIndex(arg_name, index); + return Shape4ArgNameAndIndex(arg_name, index); } - Shape* OutputShape(const std::string& arg_name, int32_t index) override { + const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Shape* MutOutputShape(const std::string& arg_name, int32_t index) override { + return MutShape4ArgNameAndIndex(arg_name, index); + } + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { + thread_local static Shape non_shape; + return non_shape; + }; + return it->second.shape(); + } + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return it->second.mut_shape(); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Stride4ArgNameAndIndex(arg_name, index); + return Stride4ArgNameAndIndex(arg_name, index); } - Stride* OutputStride(const std::string& arg_name, int32_t index) override { + const Stride& OutputStride(const std::string& arg_name, int32_t index) 
const override { return Stride4ArgNameAndIndex(arg_name, index); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Stride* MutOutputStride(const std::string& arg_name, int32_t index) override { + return MutStride4ArgNameAndIndex(arg_name, index); + } + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { + thread_local static Stride non_stride; + return non_stride; + }; + return it->second.stride(); + } + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return it->second.mut_stride(); @@ -612,8 +634,8 @@ Maybe UserOp::InferOutBlobDescs( for (const auto& pair : infer_ctx.outputs()) { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second)); out_blob_desc->set_data_type(*(infer_ctx.OutputDType(pair.first, pair.second))); - out_blob_desc->mut_shape() = *(infer_ctx.OutputShape(pair.first, pair.second)); - out_blob_desc->mut_stride() = Stride(*(infer_ctx.OutputShape(pair.first, pair.second))); + out_blob_desc->mut_shape() = infer_ctx.OutputShape(pair.first, pair.second); + out_blob_desc->mut_stride() = Stride(infer_ctx.OutputShape(pair.first, pair.second)); out_blob_desc->set_is_dynamic(*infer_ctx.OutputIsDynamic(pair.first, pair.second)); } return Maybe::Ok(); diff --git a/oneflow/ir/oneflow-extension/extension.cpp b/oneflow/ir/oneflow-extension/extension.cpp index 9954ed6dd8d..78d574b4376 100644 --- a/oneflow/ir/oneflow-extension/extension.cpp +++ b/oneflow/ir/oneflow-extension/extension.cpp @@ -49,7 +49,7 @@ REGISTER_USER_OP("mlir_jit") CHECK_EQ(ctx->inputs().size(), 2); CHECK_EQ(ctx->outputs().size(), 1); const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = 
ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; *ctx->OutputDType("out", 0) = ctx->InputDType("in", 1); return Maybe::Ok(); diff --git a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu index 0243ac36ec7..ea49e053512 100644 --- a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu +++ b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu @@ -266,9 +266,9 @@ class FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel final : public user_op: }; size_t InferTmpBufferSize(user_op::InferContext* ctx) { - const Shape* value_shape = ctx->OutputShape("value", 0); + const Shape& value_shape = ctx->OutputShape("value", 0); DataType value_dtype = *ctx->OutputDType("value", 0); - return value_shape->elem_cnt() * GetSizeOfDataType(value_dtype); + return value_shape.elem_cnt() * GetSizeOfDataType(value_dtype); } size_t InferGradTmpBufferSize(user_op::InferContext* ctx) { diff --git a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu index 97ec84abf6d..8928fc5bd9e 100644 --- a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu +++ b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu @@ -119,8 +119,8 @@ REGISTER_USER_KERNEL("generate_random_batch_permutation_indices") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { - const Shape* y_shape = ctx->OutputShape("y", 0); - const int32_t batch_size = y_shape->At(0); + const Shape& y_shape = ctx->OutputShape("y", 0); + const int32_t batch_size = y_shape.At(0); const int32_t random_value_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(float)); const int32_t sorted_value_aligned_bytes = 
GetCudaAlignedSize(batch_size * sizeof(float)); diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 6148e952101..714c9a5cbd3 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -252,7 +252,7 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O } size_t InferTmpBufferSize(user_op::InferContext* ctx) { - const Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& out_shape = ctx->OutputShape("out", 0); const user_op::TensorDesc* logical_in_tensor = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); const Shape& logical_shape = logical_in_tensor->shape(); const DataType data_type = logical_in_tensor->data_type(); @@ -278,7 +278,7 @@ size_t InferTmpBufferSize(user_op::InferContext* ctx) { } if (NdSbpHasPartialParallel(src_nd_sbp)) { // Note: when src_nd_sbp has partial_sum, need a out_size buffer to copy and add to out. 
- buf_count += out_shape->elem_cnt(); + buf_count += out_shape.elem_cnt(); } return buf_count * GetSizeOfDataType(data_type); } diff --git a/oneflow/user/kernels/nms_kernel.cu b/oneflow/user/kernels/nms_kernel.cu index 8a1f1785e0e..fa3984af8ab 100644 --- a/oneflow/user/kernels/nms_kernel.cu +++ b/oneflow/user/kernels/nms_kernel.cu @@ -132,8 +132,8 @@ class NmsGpuKernel final : public user_op::OpKernel { && (user_op::HobDataType("out", 0) == DataType::kInt8) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* in_shape = ctx->Shape4ArgNameAndIndex("in", 0); \ - int64_t num_boxes = in_shape->At(0); \ + const Shape& in_shape = ctx->Shape4ArgNameAndIndex("in", 0); \ + int64_t num_boxes = in_shape.At(0); \ int64_t blocks = CeilDiv(num_boxes, kBlockSize); \ return num_boxes * blocks * sizeof(int64_t); \ }); diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index b4dc40e3d0b..621edbe67b5 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -174,26 +174,42 @@ class UserOpInferContextHelper final { const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->shape(); + return Shape4ArgNameAndIndex(call_ctx, arg_name, index); } - Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + const Shape& OutputShape(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return Shape4ArgNameAndIndex(call_ctx, arg_name, index); } - Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + Shape* MutOutputShape(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return MutShape4ArgNameAndIndex(call_ctx, arg_name, index); + } + 
const Shape& Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->shape(); + } + Shape* MutShape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape(); } const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->stride(); } - Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { - return Stride4ArgNameAndIndex(call_ctx, arg_name, index); + const Stride& OutputStride(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->stride(); } - Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + Stride* MutOutputStride(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return MutStride4ArgNameAndIndex(call_ctx, arg_name, index); + } + const Stride& Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->stride(); + } + Stride* MutStride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride(); } const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name, @@ -317,21 +333,33 @@ class UserOpInferContext : public user_op::InferContext { const Shape& InputShape(const std::string& arg_name, int32_t index) const override { return helper_->InputShape(call_ctx_, arg_name, index); } - Shape* OutputShape(const 
std::string& arg_name, int32_t index) override { + const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return helper_->OutputShape(call_ctx_, arg_name, index); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Shape* MutOutputShape(const std::string& arg_name, int32_t index) override { + return helper_->MutOutputShape(call_ctx_, arg_name, index); + } + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index); } + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return helper_->MutShape4ArgNameAndIndex(call_ctx_, arg_name, index); + } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { return helper_->InputStride(call_ctx_, arg_name, index); } - Stride* OutputStride(const std::string& arg_name, int32_t index) override { - return helper_->OutputStride(call_ctx_, arg_name, index); + const Stride& OutputStride(const std::string& arg_name, int32_t index) const override { + return helper_->InputStride(call_ctx_, arg_name, index); + } + Stride* MutOutputStride(const std::string& arg_name, int32_t index) override { + return helper_->MutOutputStride(call_ctx_, arg_name, index); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index); } + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return helper_->MutStride4ArgNameAndIndex(call_ctx_, arg_name, index); + } const DataType& InputDType(const std::string& arg_name, int32_t index) const override { return helper_->InputDType(call_ctx_, arg_name, index); } diff --git a/oneflow/user/kernels/two_stage_reduce_kernel.cpp 
b/oneflow/user/kernels/two_stage_reduce_kernel.cpp index c76eaa9749d..429b0bd0ddf 100644 --- a/oneflow/user/kernels/two_stage_reduce_kernel.cpp +++ b/oneflow/user/kernels/two_stage_reduce_kernel.cpp @@ -127,9 +127,9 @@ template user_op::InferTmpSizeFn GenDeviceStageGradInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const Shape& out_diff_shape = ctx->InputShape("out_diff", 0); - const Shape* in_diff_shape = ctx->OutputShape("in_diff", 0); + const Shape& in_diff_shape = ctx->OutputShape("in_diff", 0); const size_t tmp_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T)); - const size_t broadcasted_tmp_bytes = GetCudaAlignedSize(in_diff_shape->elem_cnt() * sizeof(T)); + const size_t broadcasted_tmp_bytes = GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T)); return tmp_bytes + broadcasted_tmp_bytes; }; } @@ -259,7 +259,7 @@ user_op::InferTmpSizeFn GenGlobalStageGradInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const Shape& device_count_shape = ctx->InputShape("device_count", 0); const Shape& out_diff_shape = ctx->InputShape("out_diff", 0); - const Shape* in_diff_shape = ctx->OutputShape("in_diff", 0); + const Shape& in_diff_shape = ctx->OutputShape("in_diff", 0); const size_t device_count_with_mask_bytes = GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t)); const size_t global_count_bytes = @@ -268,7 +268,7 @@ user_op::InferTmpSizeFn GenGlobalStageGradInferTmpSizeFn() { GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t)); const size_t divided_buf_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T)); const size_t broadcasted_divided_buf_bytes = - GetCudaAlignedSize(in_diff_shape->elem_cnt() * sizeof(T)); + GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T)); const size_t total_bytes = device_count_with_mask_bytes + global_count_bytes + reduce_sum_tmp_bytes + divided_buf_bytes + broadcasted_divided_buf_bytes; diff --git a/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp 
b/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp index bcd7b1c5364..f18bd44f99a 100644 --- a/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp +++ b/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp @@ -193,8 +193,8 @@ class UnsortedSegmentSumHalfKernel final : public user_op::OpKernel { && (user_op::HobDataType("segment_ids", 0) == OF_PP_PAIR_SECOND(segment_ids_type)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_type))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - const Shape* out_shape = ctx->OutputShape("out", 0); \ - return GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(float)); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ + return GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float)); \ }); #define REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE(out_type, segment_ids_type) \ diff --git a/oneflow/user/kernels/where_kernel.cpp b/oneflow/user/kernels/where_kernel.cpp index ee9265f6cf5..0797dd151f7 100644 --- a/oneflow/user/kernels/where_kernel.cpp +++ b/oneflow/user/kernels/where_kernel.cpp @@ -191,13 +191,13 @@ class WhereScalarXYKernel final : public user_op::OpKernel { && (user_op::HobDataType("condition", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* out_shape = ctx->OutputShape("out", 0); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ const size_t x_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t y_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t cond_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ + 
GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ return x_bytes + y_bytes + cond_bytes; \ }); @@ -209,11 +209,11 @@ class WhereScalarXYKernel final : public user_op::OpKernel { && (user_op::HobDataType("condition", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* out_shape = ctx->OutputShape("out", 0); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ const size_t y_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t cond_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ return y_bytes + cond_bytes; \ }); @@ -225,11 +225,11 @@ class WhereScalarXYKernel final : public user_op::OpKernel { && (user_op::HobDataType("condition", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* out_shape = ctx->OutputShape("out", 0); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ const size_t x_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t cond_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ return x_bytes + cond_bytes; \ }); diff --git a/oneflow/user/ops/acc_op.cpp b/oneflow/user/ops/acc_op.cpp index 92df9df8f8e..f645c023711 100644 --- a/oneflow/user/ops/acc_op.cpp +++ b/oneflow/user/ops/acc_op.cpp @@ -30,7 +30,7 @@ namespace oneflow { return 
Maybe::Ok(); } /*static*/ Maybe AccOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/adaptive_pool_op.cpp b/oneflow/user/ops/adaptive_pool_op.cpp index 935e644ea83..35cf44f0c1d 100644 --- a/oneflow/user/ops/adaptive_pool_op.cpp +++ b/oneflow/user/ops/adaptive_pool_op.cpp @@ -31,12 +31,12 @@ Maybe InferFWTensorDesc(user_op::InferContext* ctx) { out_shape[i] = output_size.size() > i - 2 ? output_size[i - 2] : output_size[0]; } - *ctx->OutputShape("y", 0) = Shape(out_shape); + *ctx->MutOutputShape("y", 0) = Shape(out_shape); return Maybe::Ok(); } Maybe InferBWTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/arange_op.cpp b/oneflow/user/ops/arange_op.cpp index 73585347376..36a3c954c11 100644 --- a/oneflow/user/ops/arange_op.cpp +++ b/oneflow/user/ops/arange_op.cpp @@ -21,7 +21,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe ArangeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); DataType dtype = ctx->Attr("dtype"); int64_t range_elem_cnt = 0; if (IsIntegralDataType(dtype)) { @@ -88,7 +88,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/arg_sort_op.cpp b/oneflow/user/ops/arg_sort_op.cpp index e4ca90915ff..55cf61d6f05 100644 --- a/oneflow/user/ops/arg_sort_op.cpp +++ b/oneflow/user/ops/arg_sort_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe ArgSortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/argmax_op.cpp b/oneflow/user/ops/argmax_op.cpp index 58c6581eb29..17cb35709bf 100644 --- a/oneflow/user/ops/argmax_op.cpp +++ b/oneflow/user/ops/argmax_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe ArgmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { auto dim_vec = ctx->InputShape("in", 0).dim_vec(); dim_vec.pop_back(); - *ctx->OutputShape("out", 0) = Shape(std::move(dim_vec)); + *ctx->MutOutputShape("out", 0) = Shape(std::move(dim_vec)); return Maybe::Ok(); } diff --git a/oneflow/user/ops/avg_pool_op.cpp b/oneflow/user/ops/avg_pool_op.cpp index e6d1521707d..23b4f8377ad 100644 --- a/oneflow/user/ops/avg_pool_op.cpp +++ b/oneflow/user/ops/avg_pool_op.cpp @@ -27,7 +27,7 @@ typedef std::function(const user_op::UserOpWrapper& op, user_op::Add TensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) { return 
[dim](user_op::InferContext* ctx) -> Maybe { - const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0); + const Shape& x_shape = ctx->Shape4ArgNameAndIndex("x", 0); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); const std::vector& kernel_size = ctx->Attr>("kernel_size"); @@ -53,7 +53,7 @@ TensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) { << "pad should be smaller than half of kernel size"; } - const AvgPoolParams3D params_3d(dim, *x_shape, data_format, padding, kernel_size, stride, + const AvgPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride, ceil_mode, count_include_pad, divisor_override); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); *y_desc = ctx->InputTensorDesc("x", 0); diff --git a/oneflow/user/ops/bias_add_op.cpp b/oneflow/user/ops/bias_add_op.cpp index 77dfff37837..963ac103951 100644 --- a/oneflow/user/ops/bias_add_op.cpp +++ b/oneflow/user/ops/bias_add_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { << Error::RuntimeError() << "The size of tensor " << a_tensor_desc.shape().ToString() << " must match the size of tensor " << b_tensor_desc.shape().ToString() << " at dimension " << bias_add_axis; - *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("a", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/broadcast_div_grad_op.cpp b/oneflow/user/ops/broadcast_div_grad_op.cpp index c59b2436997..791fa84ad1b 100644 --- a/oneflow/user/ops/broadcast_div_grad_op.cpp +++ b/oneflow/user/ops/broadcast_div_grad_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe BroadcastDivGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("dy", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index 1478378ea7f..1e6f1456cac 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -78,8 +78,8 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { CHECK_OR_RETURN(!broadcast_axes.empty()); const Shape& in_shape = ctx->InputShape("x", 0); const Shape& like_shape = ctx->InputShape("like", 0); - Shape* out_shape = ctx->OutputShape("y", 0); - Stride* out_stride = ctx->OutputStride("y", 0); + Shape* out_shape = ctx->MutOutputShape("y", 0); + Stride* out_stride = ctx->MutOutputStride("y", 0); const AxisVector axis_vec = {broadcast_axes.begin(), broadcast_axes.end()}; CHECK_OR_RETURN(IsAxesLegal(axis_vec, like_shape, in_shape)); *out_shape = like_shape; diff --git a/oneflow/user/ops/broadcast_pow_grad_op.cpp b/oneflow/user/ops/broadcast_pow_grad_op.cpp index 21fa575b03b..ab23165638a 100644 --- a/oneflow/user/ops/broadcast_pow_grad_op.cpp +++ b/oneflow/user/ops/broadcast_pow_grad_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe BroadcastPowXGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } @@ -76,7 +76,7 @@ namespace oneflow { } /* static */ Maybe BroadcastPowYGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("dy", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/buffer_op.cpp b/oneflow/user/ops/buffer_op.cpp index eb8abde1ee6..86f8cd1e79e 100644 --- a/oneflow/user/ops/buffer_op.cpp +++ b/oneflow/user/ops/buffer_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe IdentityBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index c4d41a00be8..77cc334b087 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe CastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_to_tick_op.cpp b/oneflow/user/ops/cast_to_tick_op.cpp index bb76f5887e6..576ca9fc220 100644 --- a/oneflow/user/ops/cast_to_tick_op.cpp +++ b/oneflow/user/ops/cast_to_tick_op.cpp @@ -20,7 +20,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe CastToTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = Shape({1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index ca2b4533826..e478d910532 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { const Shape& size_shape = ctx->InputShape("size", 0); CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -39,7 +39,7 @@ namespace oneflow { const Shape& size_shape = ctx->InputShape("size", 0); CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/celu_op.cpp b/oneflow/user/ops/celu_op.cpp index 60d48152434..039124a0f6d 100644 --- a/oneflow/user/ops/celu_op.cpp +++ b/oneflow/user/ops/celu_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe CeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe CeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/clip_by_value_op.cpp b/oneflow/user/ops/clip_by_value_op.cpp index f216e077816..63363bbb153 100644 --- a/oneflow/user/ops/clip_by_value_op.cpp +++ b/oneflow/user/ops/clip_by_value_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { namespace { Maybe InferClipTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } @@ -34,7 +34,7 @@ Maybe GetClipSbpSignature(user_op::SbpContext* ctx) { } Maybe InferClipGradTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index 72854a53928..65b462ac1b0 100644 --- a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -24,7 +24,7 @@ namespace oneflow { user_op::TensorDesc* theta = ctx->OutputTensorDesc("theta", 0); CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); *ctx->IsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); 
*theta->mut_is_dynamic() = x.is_dynamic(); *theta->mut_shape() = label.shape(); @@ -72,7 +72,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0)); CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("dy", 0); *ctx->IsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/constant_op.cpp b/oneflow/user/ops/constant_op.cpp index 62d9bdcc050..4a14f638b43 100644 --- a/oneflow/user/ops/constant_op.cpp +++ b/oneflow/user/ops/constant_op.cpp @@ -20,7 +20,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe ConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } @@ -33,7 +33,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/conv_op.cpp b/oneflow/user/ops/conv_op.cpp index 64940f4d2da..ce753a087f3 100644 --- a/oneflow/user/ops/conv_op.cpp +++ b/oneflow/user/ops/conv_op.cpp @@ -308,7 +308,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); } - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x_like", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/copy_op.cpp 
b/oneflow/user/ops/copy_op.cpp index 6b7d5f994f2..f283e7c716a 100644 --- a/oneflow/user/ops/copy_op.cpp +++ b/oneflow/user/ops/copy_op.cpp @@ -42,8 +42,8 @@ Maybe> MakeCopyStream(const Symbol& in_device, } // namespace /* static */ Maybe CopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputStride("out", 0) = ctx->InputStride("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputStride("out", 0) = ctx->InputStride("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp index b8dee1ad9cc..3b8e466c923 100644 --- a/oneflow/user/ops/ctc_loss_op.cpp +++ b/oneflow/user/ops/ctc_loss_op.cpp @@ -34,8 +34,8 @@ namespace oneflow { CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); - *ctx->OutputShape("loss", 0) = Shape({batch_size}); - *ctx->OutputShape("alpha", 0) = + *ctx->MutOutputShape("loss", 0) = Shape({batch_size}); + *ctx->MutOutputShape("alpha", 0) = Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1}); return Maybe::Ok(); } @@ -78,7 +78,7 @@ namespace oneflow { CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); - *ctx->OutputShape("grad", 0) = log_probs.shape(); + *ctx->MutOutputShape("grad", 0) = log_probs.shape(); return Maybe::Ok(); } @@ -110,8 +110,8 @@ namespace oneflow { const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); const int64_t batch_size = log_probs.shape().At(1); CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0)); - *ctx->OutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)}); - *ctx->OutputShape("neg_sum_logits", 0) = Shape({batch_size, 1}); + *ctx->MutOutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)}); + 
*ctx->MutOutputShape("neg_sum_logits", 0) = Shape({batch_size, 1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp b/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp index ae09393bf85..0114b96336a 100644 --- a/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp +++ b/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp @@ -28,8 +28,8 @@ Maybe InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) { const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const int64_t bias_size = weight_desc.shape().At(1); Shape d_grad_shape({dy_desc.shape().At(0), weight_desc.shape().At(1)}); - *ctx->OutputShape("d_grad", 0) = d_grad_shape; - *ctx->OutputShape("d_bias", 0) = Shape({bias_size}); + *ctx->MutOutputShape("d_grad", 0) = d_grad_shape; + *ctx->MutOutputShape("d_bias", 0) = Shape({bias_size}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp b/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp index 8ae2e512d62..58e9b5e6912 100644 --- a/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp @@ -36,8 +36,8 @@ Maybe InferTensorDesc4MatmulBiasAddBackward(user_op::InferContext* ctx) { const int64_t bias_size = dy_desc.shape().At(1); Shape w_grad_shape({dy_desc.shape().At(1), x_desc.shape().At(1)}); - *ctx->OutputShape("w_grad", 0) = w_grad_shape; - *ctx->OutputShape("b_grad", 0) = Shape({bias_size}); + *ctx->MutOutputShape("w_grad", 0) = w_grad_shape; + *ctx->MutOutputShape("b_grad", 0) = Shape({bias_size}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp index cf4fd9d3bcd..f21853568a1 100644 --- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -25,10 +25,10 @@ Maybe InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) { const 
user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); for (int idx = weight_num - 1; idx >= 0; idx--) { const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", idx); - *ctx->OutputShape("d_weights", idx) = weight_desc.shape(); - *ctx->OutputShape("d_biases", idx) = Shape({weight_desc.shape().At(0)}); + *ctx->MutOutputShape("d_weights", idx) = weight_desc.shape(); + *ctx->MutOutputShape("d_biases", idx) = Shape({weight_desc.shape().At(0)}); } - *ctx->OutputShape("d_x", 0) = x_desc.shape(); + *ctx->MutOutputShape("d_x", 0) = x_desc.shape(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_fused_mlp_op.cpp b/oneflow/user/ops/cublas_fused_mlp_op.cpp index 9bc5d9f1b57..9369a0e303c 100644 --- a/oneflow/user/ops/cublas_fused_mlp_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_op.cpp @@ -65,12 +65,12 @@ Maybe InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) { // Set Middle result shape. long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld); int64_t aux_size = cublas_aligned_aux_ld / 32; // Cause we use int32_t as dtype - *ctx->OutputShape("cublas_aux", idx) = Shape({m, aux_size}); - *ctx->OutputShape("hidden", idx) = Shape({m, n}); + *ctx->MutOutputShape("cublas_aux", idx) = Shape({m, aux_size}); + *ctx->MutOutputShape("hidden", idx) = Shape({m, n}); // Set for next layer. k = n; } - *ctx->OutputShape("out", 0) = {m, n}; + *ctx->MutOutputShape("out", 0) = {m, n}; return Maybe::Ok(); } diff --git a/oneflow/user/ops/cum_ops.cpp b/oneflow/user/ops/cum_ops.cpp index 265a201119d..9ee5b5c123a 100644 --- a/oneflow/user/ops/cum_ops.cpp +++ b/oneflow/user/ops/cum_ops.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { Maybe CumsumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } @@ -73,7 +73,7 @@ REGISTER_USER_OP_GRAD("cumsum").SetGenBackwardOpConfFn( }); Maybe CumProdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } @@ -96,7 +96,7 @@ Maybe CumProdOp::InferDataType(user_op::InferContext* ctx) { } Maybe CumProdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/data_shuffle_op.cpp b/oneflow/user/ops/data_shuffle_op.cpp index e8e3ebfa9fa..3f0a4b9abb9 100644 --- a/oneflow/user/ops/data_shuffle_op.cpp +++ b/oneflow/user/ops/data_shuffle_op.cpp @@ -32,10 +32,10 @@ namespace oneflow { CHECK_EQ_OR_RETURN(keys_shape.At(1), num_tables) << "keys cols must equal to num_tables"; } } - *ctx->OutputShape("num_unique", 0) = Shape({1}); - *ctx->OutputShape("unique_keys", 0) = Shape({keys_shape.elem_cnt()}); - *ctx->OutputShape("unique_values", 0) = Shape({keys_shape.elem_cnt()}); - *ctx->OutputShape("inverse_indices", 0) = keys_shape; + *ctx->MutOutputShape("num_unique", 0) = Shape({1}); + *ctx->MutOutputShape("unique_keys", 0) = Shape({keys_shape.elem_cnt()}); + *ctx->MutOutputShape("unique_values", 0) = Shape({keys_shape.elem_cnt()}); + *ctx->MutOutputShape("inverse_indices", 0) = keys_shape; return Maybe::Ok(); } @@ -74,12 +74,12 @@ namespace oneflow { } const int64_t num_ids = ids_shape.elem_cnt(); const int64_t parallel_num = ctx->parallel_num(); - *ctx->OutputShape("num_unique_matrix", 0) = Shape({parallel_num * parallel_num}); - *ctx->OutputShape("inverse_unique_partition_indices", 0) = ids_shape; - 
*ctx->OutputShape("cur_rank_num_unique", 0) = Shape({1}); - *ctx->OutputShape("cur_rank_unique_ids", 0) = Shape({num_ids * parallel_num}); - *ctx->OutputShape("cur_rank_inverse_indices", 0) = Shape({num_ids * parallel_num}); - *ctx->OutputShape("cur_rank_unique_table_ids", 0) = Shape({num_ids * parallel_num}); + *ctx->MutOutputShape("num_unique_matrix", 0) = Shape({parallel_num * parallel_num}); + *ctx->MutOutputShape("inverse_unique_partition_indices", 0) = ids_shape; + *ctx->MutOutputShape("cur_rank_num_unique", 0) = Shape({1}); + *ctx->MutOutputShape("cur_rank_unique_ids", 0) = Shape({num_ids * parallel_num}); + *ctx->MutOutputShape("cur_rank_inverse_indices", 0) = Shape({num_ids * parallel_num}); + *ctx->MutOutputShape("cur_rank_unique_table_ids", 0) = Shape({num_ids * parallel_num}); return Maybe::Ok(); } @@ -135,7 +135,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids); DimVector out_dim_vec = inverse_unique_partition_indices_shape.dim_vec(); out_dim_vec.push_back(embedding_size); - *ctx->OutputShape("embeddings", 0) = Shape(out_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(out_dim_vec); return Maybe::Ok(); } @@ -179,7 +179,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids); DimVector out_dim_vec = cur_rank_inverse_indices_shape.dim_vec(); out_dim_vec.push_back(embedding_size); - *ctx->OutputShape("cur_rank_unique_embedding_grad", 0) = Shape(out_dim_vec); + *ctx->MutOutputShape("cur_rank_unique_embedding_grad", 0) = Shape(out_dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/normal_op.cpp b/oneflow/user/ops/distributions/normal_op.cpp index 736a70e5d0b..769ff12dd2e 100644 --- a/oneflow/user/ops/distributions/normal_op.cpp +++ b/oneflow/user/ops/distributions/normal_op.cpp @@ -21,7 +21,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe NormalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& shape = ctx->Attr("shape"); *out_shape = shape; return Maybe::Ok(); @@ -36,7 +36,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/uniform_int_op.cpp b/oneflow/user/ops/distributions/uniform_int_op.cpp index f01bb710f3c..63b0e39d74d 100644 --- a/oneflow/user/ops/distributions/uniform_int_op.cpp +++ b/oneflow/user/ops/distributions/uniform_int_op.cpp @@ -20,7 +20,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe UniformIntOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& shape = ctx->Attr("shape"); DimVector dim_vec; if (shape.NumAxes() > 0) { @@ -39,7 +39,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/uniform_op.cpp b/oneflow/user/ops/distributions/uniform_op.cpp index b7d566aac49..3ccb8400fab 100644 --- a/oneflow/user/ops/distributions/uniform_op.cpp +++ b/oneflow/user/ops/distributions/uniform_op.cpp @@ -20,7 +20,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe UniformOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& shape = ctx->Attr("shape"); DimVector dim_vec; if (shape.NumAxes() > 0) { @@ -39,7 +39,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/dot_op.cpp b/oneflow/user/ops/dot_op.cpp index 080a8cff539..7ea24b0d9f8 100644 --- a/oneflow/user/ops/dot_op.cpp +++ b/oneflow/user/ops/dot_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { CHECK_OR_RETURN(x.shape().NumAxes() == 1) << Error::RuntimeError() << "1D tensors expected, but got " << x.shape().NumAxes() << "D tensors"; - *ctx->OutputShape("out", 0) = Shape({}); + *ctx->MutOutputShape("out", 0) = Shape({}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp index c23d2ef28af..b74deb9ac06 100644 --- a/oneflow/user/ops/dropout_op.cpp +++ b/oneflow/user/ops/dropout_op.cpp @@ -20,8 +20,8 @@ namespace oneflow { /* static */ Maybe DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - *ctx->OutputShape("out", 0) = in_shape; - *ctx->OutputShape("mask", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("mask", 0) = in_shape; *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -53,7 +53,7 @@ namespace oneflow { /* static */ Maybe DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - *ctx->OutputShape("dx", 0) = dy_shape; + *ctx->MutOutputShape("dx", 0) = dy_shape; *ctx->OutputIsDynamic("dx", 0) = 
ctx->InputIsDynamic("dy", 0); CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape); return Maybe::Ok(); @@ -89,7 +89,7 @@ namespace oneflow { } /* static */ Maybe RandomMaskLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_b_to_s_op.cpp b/oneflow/user/ops/eager_b_to_s_op.cpp index 00cb6aee242..1d415e230f4 100644 --- a/oneflow/user/ops/eager_b_to_s_op.cpp +++ b/oneflow/user/ops/eager_b_to_s_op.cpp @@ -39,7 +39,7 @@ namespace oneflow { int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } - *ctx->OutputShape("out", 0) = Shape(dim_vec); + *ctx->MutOutputShape("out", 0) = Shape(dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_nccl_ops.cpp b/oneflow/user/ops/eager_nccl_ops.cpp index 5f574a7b1be..8af86554f51 100644 --- a/oneflow/user/ops/eager_nccl_ops.cpp +++ b/oneflow/user/ops/eager_nccl_ops.cpp @@ -24,7 +24,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EagerNcclAllReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -48,7 +48,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclBroadcastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -96,7 +96,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -120,14 +120,14 @@ namespace oneflow { /* static */ Maybe EagerNcclReduceScatterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /* static */ Maybe EagerNcclReduceScatterOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); @@ -179,7 +179,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclAllGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -226,7 +226,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = 
ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_p_to_b_op.cpp b/oneflow/user/ops/eager_p_to_b_op.cpp index f503dfcefd9..e1ad0d5ca3c 100644 --- a/oneflow/user/ops/eager_p_to_b_op.cpp +++ b/oneflow/user/ops/eager_p_to_b_op.cpp @@ -24,7 +24,7 @@ limitations under the License. namespace oneflow { // Can only be called in local /* static */ Maybe EagerPToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_p_to_s_op.cpp b/oneflow/user/ops/eager_p_to_s_op.cpp index d05bb50df12..1731cf321e2 100644 --- a/oneflow/user/ops/eager_p_to_s_op.cpp +++ b/oneflow/user/ops/eager_p_to_s_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } - *ctx->OutputShape("out", 0) = Shape(dim_vec); + *ctx->MutOutputShape("out", 0) = Shape(dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_b_op.cpp b/oneflow/user/ops/eager_s_to_b_op.cpp index e59d98bb520..9c9ff92d53b 100644 --- a/oneflow/user/ops/eager_s_to_b_op.cpp +++ b/oneflow/user/ops/eager_s_to_b_op.cpp @@ -24,7 +24,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EagerSToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_p_op.cpp b/oneflow/user/ops/eager_s_to_p_op.cpp index 711c8d84501..1caa5dfd408 100644 --- a/oneflow/user/ops/eager_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_s_to_p_op.cpp @@ -24,7 +24,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe EagerSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_s_op.cpp b/oneflow/user/ops/eager_s_to_s_op.cpp index f2ec6bc933d..11c36b19649 100644 --- a/oneflow/user/ops/eager_s_to_s_op.cpp +++ b/oneflow/user/ops/eager_s_to_s_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } - *ctx->OutputShape("out", 0) = Shape(dim_vec); + *ctx->MutOutputShape("out", 0) = Shape(dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp index 1767d96e9f4..95a3716d106 100644 --- a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp @@ -22,7 +22,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EagerSymmetricSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/elu_op.cpp b/oneflow/user/ops/elu_op.cpp index 9de85d34655..7d32b87d832 100644 --- a/oneflow/user/ops/elu_op.cpp +++ b/oneflow/user/ops/elu_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe EluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe EluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/embedding_op.cpp b/oneflow/user/ops/embedding_op.cpp index 5d124cac674..ab3a0960519 100644 --- a/oneflow/user/ops/embedding_op.cpp +++ b/oneflow/user/ops/embedding_op.cpp @@ -20,7 +20,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EmbeddingRenormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/empty_op.cpp b/oneflow/user/ops/empty_op.cpp index 92582ad145d..958843bdb03 100644 --- a/oneflow/user/ops/empty_op.cpp +++ b/oneflow/user/ops/empty_op.cpp @@ -38,8 +38,8 @@ Maybe> MakeEmptyStream(const Symbol& out_device, const bo } // namespace /* static */ Maybe EmptyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - *ctx->OutputStride("out", 0) = Stride(Shape(ctx->Attr("shape").dim_vec())); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputStride("out", 0) = Stride(Shape(ctx->Attr("shape").dim_vec())); return Maybe::Ok(); } @@ -52,8 +52,8 @@ Maybe> MakeEmptyStream(const Symbol& out_device, const bo GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; - *ctx->OutputStride("out", 0) = Stride(physical_shape); + *ctx->MutOutputShape("out", 0) = physical_shape; + *ctx->MutOutputStride("out", 0) = Stride(physical_shape); return Maybe::Ok(); } diff --git a/oneflow/user/ops/erfinv_op.cpp b/oneflow/user/ops/erfinv_op.cpp index 708e50c89c6..a0467942a39 100644 --- a/oneflow/user/ops/erfinv_op.cpp +++ b/oneflow/user/ops/erfinv_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe ErfInvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); *y_shape = x_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/expand_dims_op.cpp b/oneflow/user/ops/expand_dims_op.cpp index f5031f7a1b3..79392e43258 
100644 --- a/oneflow/user/ops/expand_dims_op.cpp +++ b/oneflow/user/ops/expand_dims_op.cpp @@ -31,7 +31,7 @@ int32_t TransformNegativeAxisToPositive(int32_t axis, const int32_t num_axes) { /* static */ Maybe ExpandDimsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const int32_t axis = TransformNegativeAxisToPositive(ctx->Attr("axis"), in_shape.NumAxes()); diff --git a/oneflow/user/ops/expand_op.cpp b/oneflow/user/ops/expand_op.cpp index 9e8cfd5c2ef..8837793c7a1 100644 --- a/oneflow/user/ops/expand_op.cpp +++ b/oneflow/user/ops/expand_op.cpp @@ -32,7 +32,7 @@ namespace oneflow { std::vector stride; CHECK_JUST(getOutShapeAndStrideForFp(in_shape, logical_expand_shape, out_shape, stride)); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); DimVector dim_vec(out_shape.begin(), out_shape.end()); *output_shape = Shape(dim_vec); @@ -90,7 +90,7 @@ namespace oneflow { CHECK_JUST(getOutShapeAndStrideForBp(logical_out_shape, logical_expand_shape, in_shape, out_shape, stride)); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); DimVector dim_vec(out_shape.begin(), out_shape.end()); *output_shape = Shape(dim_vec); return Maybe::Ok(); diff --git a/oneflow/user/ops/eye_op.cpp b/oneflow/user/ops/eye_op.cpp index 077758b2452..69823ff7943 100644 --- a/oneflow/user/ops/eye_op.cpp +++ b/oneflow/user/ops/eye_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe EyeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { int64_t rows = ctx->Attr("rows"); int64_t cols = ctx->Attr("cols"); - *ctx->OutputShape("out", 0) = Shape({rows, cols}); + *ctx->MutOutputShape("out", 0) = Shape({rows, cols}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fake_quantization_op.cpp 
b/oneflow/user/ops/fake_quantization_op.cpp index fbe6a7d8ca6..bc6dfe54a4b 100644 --- a/oneflow/user/ops/fake_quantization_op.cpp +++ b/oneflow/user/ops/fake_quantization_op.cpp @@ -30,7 +30,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); } - *ctx->OutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/fill_op.cpp b/oneflow/user/ops/fill_op.cpp index 854e9a311e7..064dd54a80c 100644 --- a/oneflow/user/ops/fill_op.cpp +++ b/oneflow/user/ops/fill_op.cpp @@ -20,9 +20,9 @@ namespace oneflow { /* static */ Maybe FillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; - Stride* out_stride = ctx->OutputStride("out", 0); + Stride* out_stride = ctx->MutOutputStride("out", 0); *out_stride = ctx->InputStride("in", 0); return Maybe::Ok(); } @@ -46,9 +46,9 @@ namespace oneflow { /* static */ Maybe FillTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; - Stride* out_stride = ctx->OutputStride("out", 0); + Stride* out_stride = ctx->MutOutputStride("out", 0); *out_stride = ctx->InputStride("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_bias_add_op.cpp b/oneflow/user/ops/fused_bias_add_op.cpp index 46f9394ff18..378e9ed50fe 100644 --- a/oneflow/user/ops/fused_bias_add_op.cpp +++ b/oneflow/user/ops/fused_bias_add_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { CHECK_GE_OR_RETURN(bias_add_axis, 0); CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("out", 0) = 
a_tensor_desc.shape(); + *ctx->MutOutputShape("out", 0) = a_tensor_desc.shape(); *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { CHECK_GE_OR_RETURN(bias_add_axis, 0); CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("dx", 0) = a_tensor_desc.shape(); + *ctx->MutOutputShape("dx", 0) = a_tensor_desc.shape(); *ctx->OutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } @@ -152,7 +152,7 @@ REGISTER_USER_OP_GRAD("fused_bias_add_gelu") CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); - *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); + *ctx->MutOutputShape("out", 0) = a_tensor_desc.shape(); *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_cross_feature_interaction_op.cpp b/oneflow/user/ops/fused_cross_feature_interaction_op.cpp index 0dfce53893d..e58a21284c7 100644 --- a/oneflow/user/ops/fused_cross_feature_interaction_op.cpp +++ b/oneflow/user/ops/fused_cross_feature_interaction_op.cpp @@ -24,11 +24,11 @@ namespace oneflow { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); CHECK_EQ_OR_RETURN(x_shape.At(1), weight_shape.At(1)) << "Matmul K dims should be equal. "; - *ctx->OutputShape("matmul_result", 0) = Shape({x_shape.At(0), weight_shape.At(0)}); + *ctx->MutOutputShape("matmul_result", 0) = Shape({x_shape.At(0), weight_shape.At(0)}); const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& bias_shape = ctx->InputShape("bias", 0); CHECK_EQ_OR_RETURN(bias_shape.At(0), x0_shape.At(1)) << "Bias dim should be equal to X0 dim1. 
"; - *ctx->OutputShape("out", 0) = x0_shape; + *ctx->MutOutputShape("out", 0) = x0_shape; return Maybe::Ok(); } @@ -59,10 +59,10 @@ namespace oneflow { user_op::InferContext* ctx) { const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); - *ctx->OutputShape("dx0", 0) = x0_shape; - *ctx->OutputShape("dw", 0) = weight_shape; - *ctx->OutputShape("dx", 0) = x0_shape; - *ctx->OutputShape("dbias", 0) = Shape({x0_shape.At(1)}); + *ctx->MutOutputShape("dx0", 0) = x0_shape; + *ctx->MutOutputShape("dw", 0) = weight_shape; + *ctx->MutOutputShape("dx", 0) = x0_shape; + *ctx->MutOutputShape("dbias", 0) = Shape({x0_shape.At(1)}); return Maybe::Ok(); } @@ -100,10 +100,10 @@ namespace oneflow { user_op::InferContext* ctx) { const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); - *ctx->OutputShape("dx0", 0) = x0_shape; - *ctx->OutputShape("dw", 0) = weight_shape; - *ctx->OutputShape("dx", 0) = x0_shape; - *ctx->OutputShape("dbias", 0) = Shape({x0_shape.At(1)}); + *ctx->MutOutputShape("dx0", 0) = x0_shape; + *ctx->MutOutputShape("dw", 0) = weight_shape; + *ctx->MutOutputShape("dx", 0) = x0_shape; + *ctx->MutOutputShape("dbias", 0) = Shape({x0_shape.At(1)}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp index 0d99cf8b489..da1d256eb67 100644 --- a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp +++ b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } const std::string& pooling = ctx->Attr("pooling"); if (pooling == "sum") { - *ctx->OutputShape("out", 0) = Shape({batch_size, vector_size}); + *ctx->MutOutputShape("out", 0) = Shape({batch_size, vector_size}); return Maybe::Ok(); } if (ctx->has_input("sparse_feature", 0)) { @@ -66,7 +66,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(output_concat_shape.At(0), batch_size); out_dim += 
output_concat_shape.At(1); } - *ctx->OutputShape("out", 0) = Shape({batch_size, out_dim}); + *ctx->MutOutputShape("out", 0) = Shape({batch_size, out_dim}); return Maybe::Ok(); } @@ -109,14 +109,14 @@ namespace oneflow { CHECK_EQ_OR_RETURN(ctx->output_size("features_grad"), ctx->input_size("features")) << "features_grad and features must have same size"; for (int64_t i = 0; i < ctx->output_size("features_grad"); ++i) { - *ctx->OutputShape("features_grad", i) = ctx->InputShape("features", i); + *ctx->MutOutputShape("features_grad", i) = ctx->InputShape("features", i); } if (ctx->has_output("output_concat_grad", 0)) { const int32_t output_concat_grad_dim = ctx->Attr("output_concat_grad_dim"); - *ctx->OutputShape("output_concat_grad", 0) = Shape({batch_size, output_concat_grad_dim}); + *ctx->MutOutputShape("output_concat_grad", 0) = Shape({batch_size, output_concat_grad_dim}); } if (ctx->has_output("sparse_feature_grad", 0)) { - *ctx->OutputShape("sparse_feature_grad", 0) = ctx->InputShape("sparse_feature", 0); + *ctx->MutOutputShape("sparse_feature_grad", 0) = ctx->InputShape("sparse_feature", 0); } return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_gru_cell_op.cpp b/oneflow/user/ops/fused_gru_cell_op.cpp index b9b6b7063f1..7b3aaee0e31 100644 --- a/oneflow/user/ops/fused_gru_cell_op.cpp +++ b/oneflow/user/ops/fused_gru_cell_op.cpp @@ -21,8 +21,8 @@ namespace oneflow { /* static */ Maybe FusedGruCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& hx_shape = ctx->InputShape("hx", 0); - *ctx->OutputShape("hy", 0) = hx_shape; - *ctx->OutputShape("workspace", 0) = Shape({hx_shape.At(0), hx_shape.At(1) * 5}); + *ctx->MutOutputShape("hy", 0) = hx_shape; + *ctx->MutOutputShape("workspace", 0) = Shape({hx_shape.At(0), hx_shape.At(1) * 5}); return Maybe::Ok(); } @@ -69,14 +69,14 @@ namespace oneflow { /* static */ Maybe FusedGruCellGradOp ::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& grad_hy_shape = ctx->InputShape("grad_hy", 
0); DimVector dim_vec({grad_hy_shape.At(0), grad_hy_shape.At(1) * 3}); - *ctx->OutputShape("grad_input_gates", 0) = Shape(dim_vec); - *ctx->OutputShape("grad_hidden_gates", 0) = Shape(dim_vec); + *ctx->MutOutputShape("grad_input_gates", 0) = Shape(dim_vec); + *ctx->MutOutputShape("grad_hidden_gates", 0) = Shape(dim_vec); - if (ctx->has_output("grad_hx", 0)) { *ctx->OutputShape("grad_hx", 0) = grad_hy_shape; } + if (ctx->has_output("grad_hx", 0)) { *ctx->MutOutputShape("grad_hx", 0) = grad_hy_shape; } if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { - *ctx->OutputShape("grad_input_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); - *ctx->OutputShape("grad_hidden_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); + *ctx->MutOutputShape("grad_input_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); + *ctx->MutOutputShape("grad_hidden_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); } return Maybe::Ok(); diff --git a/oneflow/user/ops/fused_lstm_cell_op.cpp b/oneflow/user/ops/fused_lstm_cell_op.cpp index 5ce8add4f7b..8cf2663e04c 100644 --- a/oneflow/user/ops/fused_lstm_cell_op.cpp +++ b/oneflow/user/ops/fused_lstm_cell_op.cpp @@ -21,9 +21,9 @@ namespace oneflow { /* static */ Maybe FusedLstmCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& cx_shape = ctx->InputShape("cx", 0); - *ctx->OutputShape("hy", 0) = cx_shape; - *ctx->OutputShape("cy", 0) = cx_shape; - *ctx->OutputShape("workspace", 0) = ctx->InputShape("input_gates", 0); + *ctx->MutOutputShape("hy", 0) = cx_shape; + *ctx->MutOutputShape("cy", 0) = cx_shape; + *ctx->MutOutputShape("workspace", 0) = ctx->InputShape("input_gates", 0); return Maybe::Ok(); } @@ -71,12 +71,14 @@ namespace oneflow { } /* static */ Maybe FusedLstmCellGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("grad_gates", 0) = ctx->InputShape("workspace", 0); + *ctx->MutOutputShape("grad_gates", 0) = ctx->InputShape("workspace", 0); - if (ctx->has_output("grad_cx", 0)) { 
*ctx->OutputShape("grad_cx", 0) = ctx->InputShape("cx", 0); } + if (ctx->has_output("grad_cx", 0)) { + *ctx->MutOutputShape("grad_cx", 0) = ctx->InputShape("cx", 0); + } if (ctx->has_output("grad_bias", 0)) { - *ctx->OutputShape("grad_bias", 0) = Shape({ctx->InputShape("workspace", 0).At(1)}); + *ctx->MutOutputShape("grad_bias", 0) = Shape({ctx->InputShape("workspace", 0).At(1)}); } return Maybe::Ok(); diff --git a/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp b/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp index c473ba7ea57..ced41d69fd8 100644 --- a/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp +++ b/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp @@ -65,12 +65,12 @@ Maybe InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) { // Set Middle result shape. long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld); int64_t aux_size = cublas_aligned_aux_ld / 32; // Cause we use int32_t as dtype - *ctx->OutputShape("cublas_aux", idx) = Shape({m, aux_size}); - *ctx->OutputShape("hidden", idx) = Shape({m, n}); + *ctx->MutOutputShape("cublas_aux", idx) = Shape({m, aux_size}); + *ctx->MutOutputShape("hidden", idx) = Shape({m, n}); // Set for next layer. 
k = n; } - *ctx->OutputShape("out", 0) = {m, n}; + *ctx->MutOutputShape("out", 0) = {m, n}; return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_relu_dropout_grad_op.cpp b/oneflow/user/ops/fused_relu_dropout_grad_op.cpp index 14101dd16c5..5de869d6a45 100644 --- a/oneflow/user/ops/fused_relu_dropout_grad_op.cpp +++ b/oneflow/user/ops/fused_relu_dropout_grad_op.cpp @@ -25,7 +25,7 @@ namespace oneflow { namespace { Maybe InferTensorDesc4FusedReluDropoutGrad(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp index eabeed57b06..0d9973a79fb 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp @@ -27,9 +27,9 @@ namespace oneflow { CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1), mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; - *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->MutOutputShape("y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->MutOutputShape("softmax_y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp index 235e897db47..d8d6ceda8f7 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1), mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; - *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->MutOutputShape("y", 0) = 
x_desc.shape(); *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index 20dead6c8d7..77dd85f57a4 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -20,9 +20,9 @@ namespace oneflow { /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->MutOutputShape("y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->MutOutputShape("softmax_y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp index 232a78189c9..4afaa120388 100644 --- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp +++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp @@ -41,8 +41,8 @@ namespace oneflow { CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0); int64_t num_heads = hidden_size / (head_size * 3); - *ctx->OutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); - *ctx->OutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); + *ctx->MutOutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); + *ctx->MutOutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); return Maybe::Ok(); } @@ -98,7 +98,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len); CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len); - 
*ctx->OutputShape("hidden_states_grad", 0) = h_shape; + *ctx->MutOutputShape("hidden_states_grad", 0) = h_shape; return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferPhysicalTensorDesc( diff --git a/oneflow/user/ops/gelu_op.cpp b/oneflow/user/ops/gelu_op.cpp index 39f12592c23..50c2012c83e 100644 --- a/oneflow/user/ops/gelu_op.cpp +++ b/oneflow/user/ops/gelu_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /*static*/ auto GeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; return Maybe::Ok(); } @@ -42,7 +42,7 @@ namespace oneflow { /*static*/ auto GeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp index 73b7dcb52eb..7d929383f99 100644 --- a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp +++ b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); + *ctx->MutOutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); return Maybe::Ok(); } /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferPhysicalTensorDesc( diff --git a/oneflow/user/ops/hardshrink_op.cpp b/oneflow/user/ops/hardshrink_op.cpp index 21fdae26a17..362818758b3 100644 --- 
a/oneflow/user/ops/hardshrink_op.cpp +++ b/oneflow/user/ops/hardshrink_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe HardShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe HardShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << "The shape of y_grad and y must be same."; *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/hardsigmoid_op.cpp b/oneflow/user/ops/hardsigmoid_op.cpp index 887614425ac..f56d3392058 100644 --- a/oneflow/user/ops/hardsigmoid_op.cpp +++ b/oneflow/user/ops/hardsigmoid_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe HardsigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; return Maybe::Ok(); } @@ -45,7 +45,7 @@ namespace oneflow { /* static */ Maybe HardsigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/hardswish_op.cpp b/oneflow/user/ops/hardswish_op.cpp index f7dfbc5c870..3342e1d4dbb 100644 --- a/oneflow/user/ops/hardswish_op.cpp +++ b/oneflow/user/ops/hardswish_op.cpp @@ -19,7 +19,7 @@ limitations under the 
License. namespace oneflow { /* static */ Maybe HardswishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe HardswishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/hardtanh_op.cpp b/oneflow/user/ops/hardtanh_op.cpp index 2d5208c7b0b..d2033b79870 100644 --- a/oneflow/user/ops/hardtanh_op.cpp +++ b/oneflow/user/ops/hardtanh_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe HardtanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; double min_val = ctx->Attr("min_val"); double max_val = ctx->Attr("max_val"); @@ -48,7 +48,7 @@ namespace oneflow { /* static */ Maybe HardtanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape); *dx_shape = dy_shape; double min_val = ctx->Attr("min_val"); diff --git a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp index 7ddad5a603f..564960b6e66 100644 --- a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp +++ b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe 
HierarchicalParallelCastOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -57,7 +57,7 @@ namespace oneflow { /* static */ Maybe HierarchicalParallelCastLikeOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/identity_op.cpp b/oneflow/user/ops/identity_op.cpp index 538abeb5dde..10deb96ce54 100644 --- a/oneflow/user/ops/identity_op.cpp +++ b/oneflow/user/ops/identity_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe IdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_object_preprocess_ops.cpp b/oneflow/user/ops/image_object_preprocess_ops.cpp index 5fd2cb99f38..d2b523ec994 100644 --- a/oneflow/user/ops/image_object_preprocess_ops.cpp +++ b/oneflow/user/ops/image_object_preprocess_ops.cpp @@ -35,7 +35,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -66,7 +66,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& flip_code_desc = 
ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("bbox", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); return Maybe::Ok(); } @@ -98,7 +98,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("bbox", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); return Maybe::Ok(); } @@ -132,7 +132,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("poly", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } @@ -167,7 +167,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("poly", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } @@ -194,7 +194,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { /* static */ Maybe ImageNormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = 
ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -227,7 +227,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("poly", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index 00c6d419c8b..20985964a94 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -159,7 +159,7 @@ namespace oneflow { const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp index 05affa22404..7b57a21bd01 100644 --- a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp +++ b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp @@ -24,7 +24,7 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()); - *ctx->OutputShape("out", 0) = ctx->InputShape("model", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("model", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("model", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/l2_normalize_op.cpp b/oneflow/user/ops/l2_normalize_op.cpp index d1723c41c97..4fed45fad79 100644 --- a/oneflow/user/ops/l2_normalize_op.cpp +++ 
b/oneflow/user/ops/l2_normalize_op.cpp @@ -20,8 +20,8 @@ namespace oneflow { /* static */ Maybe L2NormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - Shape* square_x_sum_shape = ctx->OutputShape("square_x_sum", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); + Shape* square_x_sum_shape = ctx->MutOutputShape("square_x_sum", 0); const int32_t axis = ctx->Attr("axis"); const float epsilon = ctx->Attr("epsilon"); CHECK_GE_OR_RETURN(axis, 0); @@ -62,7 +62,7 @@ namespace oneflow { const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& y_shape = ctx->InputShape("y", 0); const Shape& square_x_sum_shape = ctx->InputShape("square_x_sum", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); const int32_t axis = ctx->Attr("axis"); const float epsilon = ctx->Attr("epsilon"); CHECK_EQ_OR_RETURN(dy_shape, y_shape); diff --git a/oneflow/user/ops/leaky_relu_op.cpp b/oneflow/user/ops/leaky_relu_op.cpp index 09d8b318c54..fb43e8a2bf2 100644 --- a/oneflow/user/ops/leaky_relu_op.cpp +++ b/oneflow/user/ops/leaky_relu_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe LeakyReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); *y_shape = x_shape; return Maybe::Ok(); } @@ -45,7 +45,7 @@ namespace oneflow { /* static */ Maybe LeakyReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/log_softmax_op.cpp b/oneflow/user/ops/log_softmax_op.cpp index 
d8cffbf7460..8064d78941c 100644 --- a/oneflow/user/ops/log_softmax_op.cpp +++ b/oneflow/user/ops/log_softmax_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe LogSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("prob", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("prob", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -46,7 +46,7 @@ namespace oneflow { /* static */ Maybe LogSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("prob", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/masked_fill_op.cpp b/oneflow/user/ops/masked_fill_op.cpp index 327ce994ded..f4cf83edbe5 100644 --- a/oneflow/user/ops/masked_fill_op.cpp +++ b/oneflow/user/ops/masked_fill_op.cpp @@ -22,7 +22,7 @@ namespace { Maybe InferMaskedFillTensorDesc(user_op::InferContext* ctx) { const Shape& mask_shape = ctx->InputShape("mask", 0); - *ctx->OutputShape("out", 0) = mask_shape; + *ctx->MutOutputShape("out", 0) = mask_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp index 0c4ef770ac3..10ad55d4c0c 100644 --- a/oneflow/user/ops/math_binary_broadcast_ops.cpp +++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp @@ -35,21 +35,21 @@ Maybe InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) { size_t output_num_axes = std::max(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes()); if (IsZeroDimTensor(&tensor_x)) { - *ctx->OutputShape("z", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); } else if (IsZeroDimTensor(&tensor_y)) { - *ctx->OutputShape("z", 
0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); } else if (IsScalarTensor(&tensor_x)) { - *ctx->OutputShape("z", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); } else if (IsScalarTensor(&tensor_y)) { - *ctx->OutputShape("z", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); } else { const auto& x_shape = CreateLeftExtendedShape(ShapeView(tensor_x.shape()), output_num_axes); const auto& y_shape = CreateLeftExtendedShape(ShapeView(tensor_y.shape()), output_num_axes); - *ctx->OutputShape("z", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); Shape out_shape(x_shape); FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index 9604177ed77..9996bd34850 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -36,7 +36,7 @@ Maybe InferTensorDesc4Matmul(user_op::InferContext* ctx) { user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("a", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); int64_t m, n, k; // tensor a (no trans): m*k, tensor b (no trans): k*n diff --git a/oneflow/user/ops/matrix_vector_product_op.cpp b/oneflow/user/ops/matrix_vector_product_op.cpp index 91cfba1224b..fd987d0745b 100644 --- a/oneflow/user/ops/matrix_vector_product_op.cpp +++ b/oneflow/user/ops/matrix_vector_product_op.cpp @@ -26,7 +26,7 @@ Maybe InferTensorDesc4MatrixVectorProduct(user_op::InferContext* ctx) { int64_t m = a.shape().At(0); int64_t k = a.shape().At(1); 
CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << "Dim K should be equal to vector b's dim0. "; - *ctx->OutputShape("out", 0) = Shape({m}); + *ctx->MutOutputShape("out", 0) = Shape({m}); return Maybe::Ok(); } @@ -47,7 +47,7 @@ Maybe InferTensorDesc4MatrixVectorProductGradA(user_op::InferContext* ctx) const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t m = dy.shape().At(0); int64_t n = b.shape().At(0); - *ctx->OutputShape("dx", 0) = Shape({m, n}); + *ctx->MutOutputShape("dx", 0) = Shape({m, n}); return Maybe::Ok(); } @@ -58,7 +58,7 @@ Maybe InferTensorDesc4MatrixVectorProductGradB(user_op::InferContext* ctx) */ const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); int64_t n = a.shape().At(1); - *ctx->OutputShape("dx", 0) = Shape({n}); + *ctx->MutOutputShape("dx", 0) = Shape({n}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/median_op.cpp b/oneflow/user/ops/median_op.cpp index 5ca4689b037..9c80743b588 100644 --- a/oneflow/user/ops/median_op.cpp +++ b/oneflow/user/ops/median_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { } /*static*/ Maybe MedianOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& ones_shape = {1}; - *ctx->OutputShape("output", 0) = ones_shape.RemoveOnes({0}); + *ctx->MutOutputShape("output", 0) = ones_shape.RemoveOnes({0}); return Maybe::Ok(); } /*static*/ Maybe MedianOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/median_with_indices_op.cpp b/oneflow/user/ops/median_with_indices_op.cpp index d9d0d672735..2aab4ccb8cf 100644 --- a/oneflow/user/ops/median_with_indices_op.cpp +++ b/oneflow/user/ops/median_with_indices_op.cpp @@ -31,8 +31,8 @@ namespace oneflow { } /*static*/ Maybe MedianWithIndicesOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); - Shape* values_shape = ctx->OutputShape("values", 0); - Shape* indices_shape = ctx->OutputShape("indices", 0); + Shape* values_shape = ctx->MutOutputShape("values", 0); + 
Shape* indices_shape = ctx->MutOutputShape("indices", 0); const Shape& reduce_shape = CreateReducedShape(input_shape, {-1}); *values_shape = reduce_shape.RemoveOnes({-1}); *indices_shape = reduce_shape.RemoveOnes({-1}); diff --git a/oneflow/user/ops/min_max_observer_op.cpp b/oneflow/user/ops/min_max_observer_op.cpp index 3d7f186c378..84b68b8cdec 100644 --- a/oneflow/user/ops/min_max_observer_op.cpp +++ b/oneflow/user/ops/min_max_observer_op.cpp @@ -23,16 +23,16 @@ namespace oneflow { if (ctx->Attr("quantization_formula") == "google") { if (ctx->Attr("per_layer_quantization") == true) { - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); + *ctx->MutOutputShape("scale", 0) = Shape({1}); + *ctx->MutOutputShape("zero_point", 0) = Shape({1}); } else { // NOTE(Liang Depeng): For now per-channel quantization only support axis 0 - *ctx->OutputShape("scale", 0) = Shape({in_shape.At(0)}); - *ctx->OutputShape("zero_point", 0) = Shape({in_shape.At(0)}); + *ctx->MutOutputShape("scale", 0) = Shape({in_shape.At(0)}); + *ctx->MutOutputShape("zero_point", 0) = Shape({in_shape.At(0)}); } } else { // quantization_formula == "cambricon" - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); + *ctx->MutOutputShape("scale", 0) = Shape({1}); + *ctx->MutOutputShape("zero_point", 0) = Shape({1}); } return Maybe::Ok(); } diff --git a/oneflow/user/ops/mish_op.cpp b/oneflow/user/ops/mish_op.cpp index bee4ebb18a8..58dd37fdda5 100644 --- a/oneflow/user/ops/mish_op.cpp +++ b/oneflow/user/ops/mish_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe MishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe MishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index 0bcaf045247..cbfbf4b78bf 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -752,7 +752,7 @@ Maybe InferLarsUpdateDataType(user_op::InferContext* ctx) { /* static */ Maybe AdamBiasCorrectionFactorOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("train_step", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("train_step", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/moving_average_min_max_observer_op.cpp b/oneflow/user/ops/moving_average_min_max_observer_op.cpp index 434865f2d59..4e374c2de45 100644 --- a/oneflow/user/ops/moving_average_min_max_observer_op.cpp +++ b/oneflow/user/ops/moving_average_min_max_observer_op.cpp @@ -31,8 +31,8 @@ namespace oneflow { CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1); - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); + *ctx->MutOutputShape("scale", 0) = Shape({1}); + *ctx->MutOutputShape("zero_point", 0) = Shape({1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/multi_reduce_ops.cpp b/oneflow/user/ops/multi_reduce_ops.cpp index 58ceca4ff10..89022884317 100644 --- a/oneflow/user/ops/multi_reduce_ops.cpp 
+++ b/oneflow/user/ops/multi_reduce_ops.cpp @@ -23,7 +23,7 @@ namespace { Maybe InferMultiReduceOpShape(user_op::InferContext* ctx) { CHECK_GT_OR_RETURN(ctx->input_size("x"), 0) << ctx->op_name() << "must have at least 1 input"; - *ctx->OutputShape("y", 0) = Shape({}); + *ctx->MutOutputShape("y", 0) = Shape({}); return Maybe::Ok(); } @@ -67,13 +67,13 @@ Maybe InferLocalMultiReduceOpLogicalShape(user_op::InferContext* ctx) { for (int64_t i = 0; i < rank_mesh->NumAxes(); ++i) { if (any_nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= rank_mesh->At(i); } } - *ctx->OutputShape("y", 0) = Shape({split_num}); + *ctx->MutOutputShape("y", 0) = Shape({split_num}); return Maybe::Ok(); } Maybe InferLocalMultiReduceOpPhysicalShape(user_op::InferContext* ctx) { CHECK_GT_OR_RETURN(ctx->input_size("x"), 0) << ctx->op_name() << "must have at least 1 input"; - *ctx->OutputShape("y", 0) = Shape({1}); + *ctx->MutOutputShape("y", 0) = Shape({1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/narrow_op.cpp b/oneflow/user/ops/narrow_op.cpp index a8569c6784e..275041ad1a5 100644 --- a/oneflow/user/ops/narrow_op.cpp +++ b/oneflow/user/ops/narrow_op.cpp @@ -83,7 +83,7 @@ namespace oneflow { const int64_t ndim = dy_shape.NumAxes(); CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim); - *ctx->OutputShape("dx", 0) = like_shape; + *ctx->MutOutputShape("dx", 0) = like_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp index f8bf37f2771..13c39cd301e 100644 --- a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp +++ b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp @@ -23,7 +23,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ 
-65,7 +65,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -107,7 +107,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -150,7 +150,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -195,7 +195,7 @@ _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDeviceAndStream( /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nccl_logical_ops.cpp b/oneflow/user/ops/nccl_logical_ops.cpp index 5f157516389..54baf57426c 100644 --- a/oneflow/user/ops/nccl_logical_ops.cpp +++ b/oneflow/user/ops/nccl_logical_ops.cpp @@ -23,7 +23,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = 
ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -62,7 +62,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalReduceScatterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -103,7 +103,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -143,7 +143,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -185,7 +185,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -230,7 +230,7 @@ namespace oneflow { } /* static */ Maybe _ncclLogicalS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -269,7 +269,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalSendRecvOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + 
*ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nd_index_slice_ops.cpp b/oneflow/user/ops/nd_index_slice_ops.cpp index 2fa17d2d390..bdbae09b336 100644 --- a/oneflow/user/ops/nd_index_slice_ops.cpp +++ b/oneflow/user/ops/nd_index_slice_ops.cpp @@ -42,7 +42,7 @@ Maybe InferScatterNdTensorDesc(user_op::InferContext* ctx) { const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& params_shape = ctx->Attr("shape"); JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape)); - *ctx->OutputShape("out", 0) = params_shape; + *ctx->MutOutputShape("out", 0) = params_shape; return Maybe::Ok(); } @@ -56,7 +56,7 @@ Maybe InferScatterNdLikeTensorDesc(user_op::InferContext* ctx) { const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& like_shape = ctx->InputShape("like", 0); JUST(CheckScatterNdShape(like_shape, indices_shape, updates_shape)); - *ctx->OutputShape("out", 0) = like_shape; + *ctx->MutOutputShape("out", 0) = like_shape; return Maybe::Ok(); } @@ -70,7 +70,7 @@ Maybe InferTensorScatterNdOptTensorDesc(user_op::InferContext* ctx) { const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& indices_shape = ctx->InputShape("indices", 0); JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape)); - *ctx->OutputShape("out", 0) = params_shape; + *ctx->MutOutputShape("out", 0) = params_shape; return Maybe::Ok(); } @@ -122,7 +122,7 @@ Maybe GetTensorScatterNdOptSbpSignatures(user_op::SbpContext* ctx) { FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) { out_shape_vec.emplace_back(params_shape.At(i)); } - *ctx->OutputShape("out", 0) = Shape(out_shape_vec); + *ctx->MutOutputShape("out", 0) = Shape(out_shape_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nms_op.cpp b/oneflow/user/ops/nms_op.cpp index 1d9c0e29537..ea4d0a4c0f5 100644 --- a/oneflow/user/ops/nms_op.cpp +++ 
b/oneflow/user/ops/nms_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { namespace { Maybe InferNmsTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape({ctx->InputShape("in", 0).At(0)}); + *ctx->MutOutputShape("out", 0) = Shape({ctx->InputShape("in", 0).At(0)}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nvtx_range_op.cpp b/oneflow/user/ops/nvtx_range_op.cpp index 0f2bd54b2e6..c8d3509bc0f 100644 --- a/oneflow/user/ops/nvtx_range_op.cpp +++ b/oneflow/user/ops/nvtx_range_op.cpp @@ -22,7 +22,7 @@ namespace oneflow { #ifdef WITH_CUDA /* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -49,7 +49,7 @@ namespace oneflow { } /* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/one_embedding_ops.cpp b/oneflow/user/ops/one_embedding_ops.cpp index 99938d2d03d..5ac91d991d5 100644 --- a/oneflow/user/ops/one_embedding_ops.cpp +++ b/oneflow/user/ops/one_embedding_ops.cpp @@ -30,7 +30,7 @@ namespace oneflow { DimVector out_dim_vec = ids_shape.dim_vec(); const int64_t embedding_size = ctx->Attr("embedding_size"); out_dim_vec.push_back(embedding_size); - *ctx->OutputShape("embeddings", 0) = Shape(out_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(out_dim_vec); return Maybe::Ok(); } @@ -116,7 +116,7 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape) << "table_ids shape must equal to ids shape"; CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1); - *ctx->OutputShape("context", 0) = 
num_unique_ids_shape; + *ctx->MutOutputShape("context", 0) = num_unique_ids_shape; return Maybe::Ok(); } @@ -155,19 +155,19 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation(); if (ctx->has_output("embeddings", 0)) { if (use_dynamic_memory_allocation) { - *ctx->OutputShape("embeddings", 0) = Shape({1}); + *ctx->MutOutputShape("embeddings", 0) = Shape({1}); } else { DimVector embeddings_dim_vec = unique_ids_shape.dim_vec(); embeddings_dim_vec.push_back(embedding_size); - *ctx->OutputShape("embeddings", 0) = Shape(embeddings_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(embeddings_dim_vec); } } if (use_dynamic_memory_allocation) { - *ctx->OutputShape("unique_values", 0) = Shape({1}); + *ctx->MutOutputShape("unique_values", 0) = Shape({1}); } else { DimVector unique_values_dim_vec = unique_ids_shape.dim_vec(); unique_values_dim_vec.push_back(line_size); - *ctx->OutputShape("unique_values", 0) = Shape(unique_values_dim_vec); + *ctx->MutOutputShape("unique_values", 0) = Shape(unique_values_dim_vec); } return Maybe::Ok(); @@ -318,7 +318,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -346,7 +346,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 
0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -374,7 +374,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -402,7 +402,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -430,7 +430,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/ones_like_op.cpp b/oneflow/user/ops/ones_like_op.cpp index c64eefc2a0f..74f49c31590 100644 --- a/oneflow/user/ops/ones_like_op.cpp +++ b/oneflow/user/ops/ones_like_op.cpp @@ -33,8 +33,8 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - 
*ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - *ctx->OutputStride("out", 0) = ctx->InputStride("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputStride("out", 0) = ctx->InputStride("like", 0); return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/p2p_comm_op.cpp b/oneflow/user/ops/p2p_comm_op.cpp index 0c6998bdb87..1103106a736 100644 --- a/oneflow/user/ops/p2p_comm_op.cpp +++ b/oneflow/user/ops/p2p_comm_op.cpp @@ -48,7 +48,7 @@ Maybe> GetRecvOutputDeivce(user_op::DeviceAndStreamInferContext* /*static*/ Maybe RecvOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); } /*static*/ Maybe RecvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->Attr("shape"); + *ctx->MutOutputShape("out", 0) = ctx->Attr("shape"); return Maybe::Ok(); } /*static*/ Maybe RecvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/pad_op.cpp b/oneflow/user/ops/pad_op.cpp index d1d020ed355..ce545d812f5 100644 --- a/oneflow/user/ops/pad_op.cpp +++ b/oneflow/user/ops/pad_op.cpp @@ -40,7 +40,7 @@ namespace oneflow { FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i]; } - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + *ctx->MutOutputShape("y", 0) = Shape(y_dim_vec); return Maybe::Ok(); } /*static*/ Maybe PadOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/padding_ops.cpp b/oneflow/user/ops/padding_ops.cpp index 41ef1da54ea..400f846e9fd 100644 --- a/oneflow/user/ops/padding_ops.cpp +++ b/oneflow/user/ops/padding_ops.cpp @@ -74,7 +74,7 @@ Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + *ctx->MutOutputShape("y", 0) = 
Shape(y_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -113,7 +113,7 @@ Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + *ctx->MutOutputShape("dx", 0) = Shape(dx_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -162,7 +162,7 @@ REGISTER_USER_OP_GRAD("reflection_pad2d") y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + *ctx->MutOutputShape("y", 0) = Shape(y_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -201,7 +201,7 @@ REGISTER_USER_OP_GRAD("reflection_pad2d") dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + *ctx->MutOutputShape("dx", 0) = Shape(dx_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/parallel_cast_op.cpp b/oneflow/user/ops/parallel_cast_op.cpp index 9d25b9504de..e24f264cd8a 100644 --- a/oneflow/user/ops/parallel_cast_op.cpp +++ b/oneflow/user/ops/parallel_cast_op.cpp @@ -23,7 +23,7 @@ namespace oneflow { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe ParallelCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp 
index 1798e91fe6d..9ca056933aa 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -111,11 +111,11 @@ namespace oneflow { } /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("boxing_disabled_sampled_weight_diff", 0) = + *ctx->MutOutputShape("boxing_disabled_sampled_weight_diff", 0) = ctx->InputShape("sampled_weight_diff", 0); *ctx->OutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = ctx->InputIsDynamic("sampled_weight_diff", 0); - *ctx->OutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); + *ctx->MutOutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); *ctx->OutputIsDynamic("boxing_disabled_sampled_label", 0) = ctx->InputIsDynamic("sampled_label", 0); return Maybe::Ok(); diff --git a/oneflow/user/ops/prelu_op.cpp b/oneflow/user/ops/prelu_op.cpp index 6cd352ba5ba..1b19189f328 100644 --- a/oneflow/user/ops/prelu_op.cpp +++ b/oneflow/user/ops/prelu_op.cpp @@ -40,7 +40,7 @@ namespace oneflow { } /*static*/ Maybe PreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); *y_shape = x_shape; @@ -91,8 +91,8 @@ namespace oneflow { /*static*/ Maybe PreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - Shape* alpha_diff_shape = ctx->OutputShape("alpha_diff", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); + Shape* alpha_diff_shape = ctx->MutOutputShape("alpha_diff", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); 
CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1)); diff --git a/oneflow/user/ops/quantization_op.cpp b/oneflow/user/ops/quantization_op.cpp index 2396a1a1685..759b65472bf 100644 --- a/oneflow/user/ops/quantization_op.cpp +++ b/oneflow/user/ops/quantization_op.cpp @@ -68,7 +68,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); } - *ctx->OutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; return Maybe::Ok(); } /*static*/ Maybe QuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/randperm_op.cpp b/oneflow/user/ops/randperm_op.cpp index 956902154ae..7075f37327d 100644 --- a/oneflow/user/ops/randperm_op.cpp +++ b/oneflow/user/ops/randperm_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { } /*static*/ Maybe RandpermOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe RandpermOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); int32_t n = ctx->Attr("n"); CHECK_GE_OR_RETURN(n, 0) << Error::RuntimeError() << "Trying to create tensor with negative dimension " << n << ":" @@ -45,7 +45,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/reduce_ops.cpp b/oneflow/user/ops/reduce_ops.cpp index 5ac0a70038c..fbfcff77d8f 100644 --- a/oneflow/user/ops/reduce_ops.cpp +++ b/oneflow/user/ops/reduce_ops.cpp @@ -23,8 +23,8 @@ namespace oneflow { Maybe InferTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input_tensor", 0); const auto& reduce_axes = ctx->Attr>("axis"); - Shape* 
output_shape = ctx->OutputShape("output_tensor", 0); - Stride* output_stride = ctx->OutputStride("output_tensor", 0); + Shape* output_shape = ctx->MutOutputShape("output_tensor", 0); + Stride* output_stride = ctx->MutOutputStride("output_tensor", 0); // For 0-dim Tensor if (reduce_axes.empty()) { *output_shape = input_shape; diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp index 38e4f58328a..6b87f2fd4c0 100644 --- a/oneflow/user/ops/relu_op.cpp +++ b/oneflow/user/ops/relu_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { } /*static*/ Maybe ReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("x", 0); - Shape* out_shape = ctx->OutputShape("y", 0); + Shape* out_shape = ctx->MutOutputShape("y", 0); *out_shape = in_shape; return Maybe::Ok(); } @@ -53,7 +53,7 @@ namespace oneflow { /*static*/ Maybe ReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "Tensors y and dy must have the same shape"; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/repeat_op.cpp b/oneflow/user/ops/repeat_op.cpp index 60b281854dc..2f00322b3a2 100644 --- a/oneflow/user/ops/repeat_op.cpp +++ b/oneflow/user/ops/repeat_op.cpp @@ -31,7 +31,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/reshape_like_op.cpp b/oneflow/user/ops/reshape_like_op.cpp index 7b11d6de6f0..e40cab51ebd 100644 --- a/oneflow/user/ops/reshape_like_op.cpp +++ 
b/oneflow/user/ops/reshape_like_op.cpp @@ -44,7 +44,7 @@ namespace oneflow { << "The element number of the in tensor must be equal to the element number of the " "like tensor, " << "but got " << in_shape.elem_cnt() << " and " << like_shape.elem_cnt(); - *ctx->OutputShape("out", 0) = like_shape; + *ctx->MutOutputShape("out", 0) = like_shape; return Maybe::Ok(); } /*static*/ Maybe ReshapeLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/roi_align_op.cpp b/oneflow/user/ops/roi_align_op.cpp index c2a45e6eedc..090c29674a5 100644 --- a/oneflow/user/ops/roi_align_op.cpp +++ b/oneflow/user/ops/roi_align_op.cpp @@ -37,7 +37,7 @@ namespace oneflow { CHECK_EQ(rois_shape.NumAxes(), 2); CHECK_EQ(rois_shape.At(1), 5); // y: (R, C, pool_h, pool_w) - *ctx->OutputShape("y", 0) = Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w}); + *ctx->MutOutputShape("y", 0) = Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w}); return Maybe::Ok(); } /*static*/ Maybe RoiAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -81,7 +81,7 @@ namespace oneflow { // y: (R, C, pool_h, pool_w) const Shape& y_shape = Shape({rois_shape.At(0), x_like_shape.At(1), pooled_h, pooled_w}); CHECK_EQ_OR_RETURN(y_shape, dy_shape); - *ctx->OutputShape("dx", 0) = x_like_shape; + *ctx->MutOutputShape("dx", 0) = x_like_shape; return Maybe::Ok(); } /*static*/ Maybe RoiAlignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/roll_op.cpp b/oneflow/user/ops/roll_op.cpp index b07077d814b..01fd2742c3b 100644 --- a/oneflow/user/ops/roll_op.cpp +++ b/oneflow/user/ops/roll_op.cpp @@ -45,7 +45,7 @@ namespace oneflow { } /*static*/ Maybe RollOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - *ctx->OutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; return Maybe::Ok(); } /*static*/ Maybe 
RollOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/same_padding_op.cpp b/oneflow/user/ops/same_padding_op.cpp index 267faf5fecf..40ca7ccd3f9 100644 --- a/oneflow/user/ops/same_padding_op.cpp +++ b/oneflow/user/ops/same_padding_op.cpp @@ -108,7 +108,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SamePaddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x_like", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/scalar_logical_op.cpp b/oneflow/user/ops/scalar_logical_op.cpp index 8c0786c2804..a242b67f924 100644 --- a/oneflow/user/ops/scalar_logical_op.cpp +++ b/oneflow/user/ops/scalar_logical_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { return Maybe::Ok(); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); \ *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ return Maybe::Ok(); \ } \ diff --git a/oneflow/user/ops/scalar_math_op.cpp b/oneflow/user/ops/scalar_math_op.cpp index 3627acde3cf..6712023f60c 100644 --- a/oneflow/user/ops/scalar_math_op.cpp +++ b/oneflow/user/ops/scalar_math_op.cpp @@ -42,7 +42,7 @@ Maybe GetSbp4ScalarMul(user_op::SbpContext* ctx) { #define IMPLEMENT_SCALAR_MATH_OP_FUNCS(op_name, get_sbp_fn) \ /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { return get_sbp_fn(ctx); } \ /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); \ *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ return Maybe::Ok(); \ } \ @@ -71,7 +71,7 @@ 
IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath) return Maybe::Ok(); } /*static*/ Maybe ScalarPowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } /*static*/ Maybe ScalarPowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -92,7 +92,7 @@ IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath) return Maybe::Ok(); } /*static*/ Maybe ScalarReversePowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } /*static*/ Maybe ScalarReversePowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/search_sorted_op.cpp b/oneflow/user/ops/search_sorted_op.cpp index 368114c17ec..1a96a0a9ccb 100644 --- a/oneflow/user/ops/search_sorted_op.cpp +++ b/oneflow/user/ops/search_sorted_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe SearchSortedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("values", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("values", 0); return Maybe::Ok(); } @@ -54,7 +54,7 @@ namespace oneflow { } /* static */ Maybe SearchSortedScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape({}); + *ctx->MutOutputShape("out", 0) = Shape({}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/selu_op.cpp b/oneflow/user/ops/selu_op.cpp index e23a95c8526..cb0de53192e 100644 --- a/oneflow/user/ops/selu_op.cpp +++ b/oneflow/user/ops/selu_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -51,7 +51,7 @@ namespace oneflow { /*static*/ Maybe SeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "Tensors dy and x must be the same shape"; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp index 8e35ae69ab1..cc459d2a605 100644 --- a/oneflow/user/ops/silu_op.cpp +++ b/oneflow/user/ops/silu_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SiluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe 
SiluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -51,7 +51,7 @@ namespace oneflow { /*static*/ Maybe SiluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 3ae88200258..c0b7bea6caa 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -170,7 +170,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { const int64_t diff = stop - start - 1; dim_vec[i] = diff / step + 1; } - *ctx->OutputShape("y", 0) = Shape(dim_vec); + *ctx->MutOutputShape("y", 0) = Shape(dim_vec); return Maybe::Ok(); } /*static*/ Maybe SliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -198,7 +198,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const TensorSliceView& slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, y_nd_sbp, logical_shape, parallel_id); - *ctx->OutputShape("y", 0) = Shape(slice_view.shape()); + *ctx->MutOutputShape("y", 0) = Shape(slice_view.shape()); return Maybe::Ok(); } /*static*/ Maybe SliceOp::InferDataType(user_op::InferContext* ctx) { @@ -253,7 +253,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { << Error::RuntimeError() << "The size of step list must be equal to the dimension of ref tensor, " << "but got " << step_vec.size() << " and " << ndim; - *ctx->OutputShape("dx", 0) = like_shape; + *ctx->MutOutputShape("dx", 0) = like_shape; return Maybe::Ok(); } /*static*/ Maybe 
SliceGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp index 1b31f895407..f193e333c5d 100644 --- a/oneflow/user/ops/softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp @@ -51,7 +51,7 @@ namespace oneflow { FOR_RANGE(int64_t, i, 0, num_out_axes) { out_dim_vector.emplace_back(prediction_desc.shape().At(i)); } - *ctx->OutputShape("prob", 0) = ctx->InputShape("prediction", 0); + *ctx->MutOutputShape("prob", 0) = ctx->InputShape("prediction", 0); *ctx->OutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0); user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); @@ -118,7 +118,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape()) << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prob " << prob_desc.shape(); - *ctx->OutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0); + *ctx->MutOutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0); *ctx->OutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/softmax_op.cpp b/oneflow/user/ops/softmax_op.cpp index a726d561073..4dfc29ad88d 100644 --- a/oneflow/user/ops/softmax_op.cpp +++ b/oneflow/user/ops/softmax_op.cpp @@ -29,7 +29,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -54,7 +54,7 @@ namespace oneflow { /*static*/ Maybe SoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const 
Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of y " << y_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/softplus_op.cpp b/oneflow/user/ops/softplus_op.cpp index 2a772b661c0..18ec0cfc439 100644 --- a/oneflow/user/ops/softplus_op.cpp +++ b/oneflow/user/ops/softplus_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe SoftplusOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe SoftplusGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/softshrink_op.cpp b/oneflow/user/ops/softshrink_op.cpp index 95ec290270b..3bed51333d4 100644 --- a/oneflow/user/ops/softshrink_op.cpp +++ b/oneflow/user/ops/softshrink_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe SoftShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe SoftShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of y " << y_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/softsign_op.cpp b/oneflow/user/ops/softsign_op.cpp index 61e45f781e6..2b474b67f19 100644 --- a/oneflow/user/ops/softsign_op.cpp +++ b/oneflow/user/ops/softsign_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SoftsignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SoftsignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -51,7 +51,7 @@ namespace oneflow { /*static*/ Maybe SoftsignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/sort_op.cpp b/oneflow/user/ops/sort_op.cpp index f2dd5e6f89b..5c3add243b3 100644 --- a/oneflow/user/ops/sort_op.cpp +++ b/oneflow/user/ops/sort_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { return 
Maybe::Ok(); } /*static*/ Maybe SortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp index b661910fe8c..adce0aa9b7f 100644 --- a/oneflow/user/ops/sparse_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp @@ -62,7 +62,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape()) << Error::RuntimeError() << "The size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); - *ctx->OutputShape("prediction_diff", 0) = prediction_desc.shape(); + *ctx->MutOutputShape("prediction_diff", 0) = prediction_desc.shape(); *ctx->OutputIsDynamic("prediction_diff", 0) = prediction_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp index 0d77af3f218..7e02cb9fd23 100644 --- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp @@ -43,7 +43,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { } *ctx->OutputIsDynamic("prob", 0) = prediction_desc.is_dynamic(); // 'prob' is just for compute prediction's grad, prob's grad will be ignored - *ctx->OutputShape("prob", 0) = prediction_desc.shape(); + *ctx->MutOutputShape("prob", 0) = prediction_desc.shape(); user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); *out_desc->mut_shape() = label_desc.shape(); @@ -75,7 +75,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape()) << Error::RuntimeError() << "The 
size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); - *ctx->OutputShape("prediction_diff", 0) = prob_desc.shape(); + *ctx->MutOutputShape("prediction_diff", 0) = prob_desc.shape(); *ctx->OutputIsDynamic("prediction_diff", 0) = prob_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/squeeze_op.cpp b/oneflow/user/ops/squeeze_op.cpp index d6c9cb111a4..5fe2422a6a8 100644 --- a/oneflow/user/ops/squeeze_op.cpp +++ b/oneflow/user/ops/squeeze_op.cpp @@ -63,7 +63,7 @@ Maybe CheckAndLabelAxesToSqueezeMinusOne(const AxisVector& axes, DimVector } /*static*/ Maybe SqueezeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); AxisVector fixed_axes_vec; JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_shape.NumAxes(), &fixed_axes_vec)); diff --git a/oneflow/user/ops/ssp_variable_proxy_op.cpp b/oneflow/user/ops/ssp_variable_proxy_op.cpp index 9a5a31262a7..00299abcd86 100644 --- a/oneflow/user/ops/ssp_variable_proxy_op.cpp +++ b/oneflow/user/ops/ssp_variable_proxy_op.cpp @@ -31,8 +31,8 @@ namespace oneflow { } /*static*/ Maybe SspVariableProxyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& var_shape = ctx->InputShape("var", 0); - *ctx->OutputShape("ref", 0) = var_shape; - *ctx->OutputShape("value", 0) = var_shape; + *ctx->MutOutputShape("ref", 0) = var_shape; + *ctx->MutOutputShape("value", 0) = var_shape; return Maybe::Ok(); } /*static*/ Maybe SspVariableProxyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/tf_pool_op.cpp b/oneflow/user/ops/tf_pool_op.cpp index 39afc8478b8..73a6ab3380e 100644 --- a/oneflow/user/ops/tf_pool_op.cpp +++ b/oneflow/user/ops/tf_pool_op.cpp @@ -51,7 +51,7 @@ TensorDescInferFn MakeFwTensorDescInferFn(const int32_t dim) { } Maybe BwTensorDescInferFn(user_op::InferContext* ctx) 
{ - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/tf_prelu_op.cpp b/oneflow/user/ops/tf_prelu_op.cpp index b4880e201e7..f183d82e607 100644 --- a/oneflow/user/ops/tf_prelu_op.cpp +++ b/oneflow/user/ops/tf_prelu_op.cpp @@ -102,7 +102,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type()); *dx_desc->mut_shape() = x_desc.shape(); *dx_desc->mut_is_dynamic() = x_desc.is_dynamic(); - *ctx->OutputShape("alpha_diff", 0) = alpha_desc.shape(); + *ctx->MutOutputShape("alpha_diff", 0) = alpha_desc.shape(); *ctx->OutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/threshold_op.cpp b/oneflow/user/ops/threshold_op.cpp index 3cf10ab9dae..f2ad58f111f 100644 --- a/oneflow/user/ops/threshold_op.cpp +++ b/oneflow/user/ops/threshold_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe ThresholdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe ThresholdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/to_contiguous_op.cpp b/oneflow/user/ops/to_contiguous_op.cpp index 95a80c3e1b6..09ce23959f8 100644 --- a/oneflow/user/ops/to_contiguous_op.cpp +++ b/oneflow/user/ops/to_contiguous_op.cpp @@ -24,8 +24,8 @@ namespace oneflow { } /*static*/ Maybe ToContiguousOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputStride("out", 0) = Stride(in_desc.shape()); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputStride("out", 0) = Stride(in_desc.shape()); return Maybe::Ok(); } /*static*/ Maybe ToContiguousOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/top_k_op.cpp b/oneflow/user/ops/top_k_op.cpp index 0bcf295d5bd..c41051e8252 100644 --- a/oneflow/user/ops/top_k_op.cpp +++ b/oneflow/user/ops/top_k_op.cpp @@ -29,7 +29,7 @@ namespace oneflow { } /*static*/ Maybe TopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; out_shape->Set(in_shape.NumAxes() - 1, std::min(ctx->Attr("k"), static_cast(in_shape.dim_vec().back()))); diff --git 
a/oneflow/user/ops/tuple_identity_op.cpp b/oneflow/user/ops/tuple_identity_op.cpp index dd98f2fef74..7e2631989d0 100644 --- a/oneflow/user/ops/tuple_identity_op.cpp +++ b/oneflow/user/ops/tuple_identity_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { const int64_t in_size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); for (int64_t i = 0; i < in_size; ++i) { - *ctx->OutputShape("out", i) = ctx->InputShape("in", i); + *ctx->MutOutputShape("out", i) = ctx->InputShape("in", i); *ctx->IsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i); } return Maybe::Ok(); diff --git a/oneflow/user/ops/two_stage_reduce_ops.cpp b/oneflow/user/ops/two_stage_reduce_ops.cpp index 9fbb79e1da0..0c65508c8b6 100644 --- a/oneflow/user/ops/two_stage_reduce_ops.cpp +++ b/oneflow/user/ops/two_stage_reduce_ops.cpp @@ -33,7 +33,7 @@ Maybe InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx const Shape& input_shape = ctx->InputShape("in", 0); const auto& axis = ctx->Attr>("axis"); const int64_t num_axes = input_shape.NumAxes(); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); if (axis.empty()) { *output_shape = Shape::Ones(num_axes); } else { @@ -63,8 +63,8 @@ Maybe InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx *output_shape = Shape(dim_vec); } - *ctx->OutputShape("mask", 0) = input_shape; - *ctx->OutputShape("count", 0) = *output_shape; + *ctx->MutOutputShape("mask", 0) = input_shape; + *ctx->MutOutputShape("count", 0) = *output_shape; return Maybe::Ok(); } @@ -72,7 +72,7 @@ Maybe InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx Maybe InferReduceDeviceStagePhysicalTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("in", 0); const auto& axis = ctx->Attr>("axis"); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); if (axis.empty()) { *output_shape = 
Shape::Ones(input_shape.NumAxes()); } else { @@ -81,8 +81,8 @@ Maybe InferReduceDeviceStagePhysicalTensorDescFn(user_op::InferContext* ct *output_shape = reduced_shape; } - *ctx->OutputShape("mask", 0) = input_shape; - *ctx->OutputShape("count", 0) = *output_shape; + *ctx->MutOutputShape("mask", 0) = input_shape; + *ctx->MutOutputShape("count", 0) = *output_shape; return Maybe::Ok(); } @@ -96,7 +96,7 @@ Maybe InferReduceDeviceStageGradDtypeFn(user_op::InferContext* ctx) { Maybe InferReduceDeviceStageGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputShape("out_diff", 0), ctx->InputShape("count", 0)); - *ctx->OutputShape("in_diff", 0) = ctx->InputShape("mask", 0); + *ctx->MutOutputShape("in_diff", 0) = ctx->InputShape("mask", 0); return Maybe::Ok(); } @@ -114,7 +114,7 @@ Maybe InferReduceGlobalStageTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(input_shape, device_count_shape); const auto& axis = ctx->Attr>("axis"); bool keepdims = ctx->Attr("keepdims"); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); if (axis.empty()) { if (keepdims) { *output_shape = Shape::Ones(input_shape.NumAxes()); @@ -131,7 +131,7 @@ Maybe InferReduceGlobalStageTensorDescFn(user_op::InferContext* ctx) { } } - *ctx->OutputShape("mask", 0) = input_shape; + *ctx->MutOutputShape("mask", 0) = input_shape; return Maybe::Ok(); } @@ -149,7 +149,7 @@ Maybe InferReduceGlobalStageGradTensorDescFn(user_op::InferContext* ctx) { const Shape& mask_shape = ctx->InputShape("mask", 0); const Shape& device_count_shape = ctx->InputShape("device_count", 0); CHECK_EQ_OR_RETURN(device_count_shape, mask_shape); - *ctx->OutputShape("in_diff", 0) = mask_shape; + *ctx->MutOutputShape("in_diff", 0) = mask_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/unfold_fold_op.cpp b/oneflow/user/ops/unfold_fold_op.cpp index 0560561604c..ce851cce8c7 100644 --- a/oneflow/user/ops/unfold_fold_op.cpp +++ 
b/oneflow/user/ops/unfold_fold_op.cpp @@ -58,7 +58,7 @@ Maybe UnfoldTensorDescInferFn(user_op::InferContext* ctx) { * std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); y_shape.at(2) = std::accumulate(dhw_shape.begin(), dhw_shape.end(), 1, std::multiplies()); - *ctx->OutputShape("y", 0) = Shape(y_shape); + *ctx->MutOutputShape("y", 0) = Shape(y_shape); return Maybe::Ok(); } @@ -118,7 +118,7 @@ Maybe FoldTensorDescInferFn(user_op::InferContext* ctx) { y_shape.at(2) = output_size[0]; y_shape.at(3) = output_size[1]; - *ctx->OutputShape("y", 0) = Shape(y_shape); + *ctx->MutOutputShape("y", 0) = Shape(y_shape); return Maybe::Ok(); } diff --git a/oneflow/user/ops/unfold_tensor_op.cpp b/oneflow/user/ops/unfold_tensor_op.cpp index 52d1c068e6b..03c24c7bc29 100644 --- a/oneflow/user/ops/unfold_tensor_op.cpp +++ b/oneflow/user/ops/unfold_tensor_op.cpp @@ -57,7 +57,7 @@ namespace oneflow { out_shape.at(d) = in_size_at_d; } } - *ctx->OutputShape("y", 0) = Shape(out_shape); + *ctx->MutOutputShape("y", 0) = Shape(out_shape); return Maybe::Ok(); } /*static*/ Maybe UnfoldTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/unsorted_segment_sum_op.cpp b/oneflow/user/ops/unsorted_segment_sum_op.cpp index 5df5e81e451..76d03477f23 100644 --- a/oneflow/user/ops/unsorted_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_segment_sum_op.cpp @@ -52,7 +52,7 @@ namespace oneflow { const Shape& data_shape = ctx->InputShape("data", 0); const int64_t axis = ctx->Attr("axis"); const int64_t num_segments = ctx->Attr("num_segments"); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); DimVector dim_vec; @@ -163,7 +163,7 @@ REGISTER_USER_OP_GRAD("unsorted_segment_sum") FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); } - 
*ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); *ctx->IsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index e1d05c1b097..2edea6f8b12 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -244,7 +244,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 3) << "upsample_linear_1d_grad only supports NCH"; @@ -269,7 +269,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 3) << "upsample_nearest_1d_grad only supports NCH"; @@ -295,7 +295,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_nearest_2d_grad only supports NCHW"; @@ -322,7 +322,7 @@ namespace oneflow { /*static*/ Maybe UpsampleBilinear2DGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); 
CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_bilinear_2d_grad only supports NCHW"; @@ -348,7 +348,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_bicubic_2d_grad only supports NCHW"; @@ -374,7 +374,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 5) << "upsample_nearest_3d_grad only supports NCDHW"; @@ -401,7 +401,7 @@ namespace oneflow { /*static*/ Maybe UpsampleTrilinear3DGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 5) << "upsample_trilinear_3d_grad only supports NCDHW"; diff --git a/oneflow/user/ops/util_ops.cpp b/oneflow/user/ops/util_ops.cpp index 0be4ce5f115..2b4a68a986a 100644 --- a/oneflow/user/ops/util_ops.cpp +++ b/oneflow/user/ops/util_ops.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe IsNanOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { } /* static */ Maybe IsInfOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/variance_op.cpp b/oneflow/user/ops/variance_op.cpp index c1e578e6947..33caa475c58 100644 --- a/oneflow/user/ops/variance_op.cpp +++ b/oneflow/user/ops/variance_op.cpp @@ -27,7 +27,7 @@ Maybe VarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const AxisVector reduce_axes_vec = {reduce_axes.begin(), reduce_axes.end()}; const Shape& reduce_shape = CreateReducedShape(input_shape, reduce_axes_vec); const bool keepdim = ctx->Attr("keepdim"); - Shape* output_shape = ctx->OutputShape("output", 0); + Shape* output_shape = ctx->MutOutputShape("output", 0); if (keepdim) { *output_shape = reduce_shape; } else { diff --git a/oneflow/user/ops/vector_matrix_product_op.cpp b/oneflow/user/ops/vector_matrix_product_op.cpp index 834ace4ab4c..6d85721cd30 100644 --- a/oneflow/user/ops/vector_matrix_product_op.cpp +++ b/oneflow/user/ops/vector_matrix_product_op.cpp @@ -26,7 +26,7 @@ Maybe InferTensorDesc4VectorMatrixProduct(user_op::InferContext* ctx) { int64_t k = a.shape().At(0); CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << "Dim K should be equal to vector b's dim0. 
"; int64_t n = b.shape().At(1); - *ctx->OutputShape("out", 0) = Shape({n}); + *ctx->MutOutputShape("out", 0) = Shape({n}); return Maybe::Ok(); } @@ -45,7 +45,7 @@ Maybe InferTensorDesc4VectorMatrixProductGradA(user_op::InferContext* ctx) */ const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t k = b.shape().At(0); - *ctx->OutputShape("dx", 0) = Shape({k}); + *ctx->MutOutputShape("dx", 0) = Shape({k}); return Maybe::Ok(); } @@ -58,7 +58,7 @@ Maybe InferTensorDesc4VectorMatrixProductGradB(user_op::InferContext* ctx) const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); int64_t k = a.shape().At(0); int64_t n = dy.shape().At(0); - *ctx->OutputShape("dx", 0) = Shape({k, n}); + *ctx->MutOutputShape("dx", 0) = Shape({k, n}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index e49ffb19fe6..4a4baf75285 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -81,11 +81,11 @@ Maybe InferWhereTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& y_shape = ctx->InputShape("y", 0); if (x_shape == y_shape && y_shape == cond_shape) { - *ctx->OutputShape("out", 0) = cond_shape; + *ctx->MutOutputShape("out", 0) = cond_shape; } else { Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape)); max_shape = *JUST(GetBroadcastShape(max_shape, y_shape)); - *ctx->OutputShape("out", 0) = max_shape; + *ctx->MutOutputShape("out", 0) = max_shape; } return Maybe::Ok(); } @@ -94,10 +94,10 @@ Maybe InferWhereXScalarTensorDesc(user_op::InferContext* ctx) { const Shape& cond_shape = ctx->InputShape("condition", 0); const Shape& y_shape = ctx->InputShape("y", 0); if (cond_shape == y_shape) { - *ctx->OutputShape("out", 0) = cond_shape; + *ctx->MutOutputShape("out", 0) = cond_shape; } else { Shape max_shape = *JUST(GetBroadcastShape(cond_shape, y_shape)); - *ctx->OutputShape("out", 0) = max_shape; + *ctx->MutOutputShape("out", 0) = max_shape; } return 
Maybe::Ok(); } @@ -106,16 +106,16 @@ Maybe InferWhereYScalarTensorDesc(user_op::InferContext* ctx) { const Shape& cond_shape = ctx->InputShape("condition", 0); const Shape& x_shape = ctx->InputShape("x", 0); if (cond_shape == x_shape) { - *ctx->OutputShape("out", 0) = cond_shape; + *ctx->MutOutputShape("out", 0) = cond_shape; } else { Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape)); - *ctx->OutputShape("out", 0) = max_shape; + *ctx->MutOutputShape("out", 0) = max_shape; } return Maybe::Ok(); } Maybe InferWhereXYScalarTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("condition", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("condition", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/zero_like_op.cpp b/oneflow/user/ops/zero_like_op.cpp index ad648779684..e301865998f 100644 --- a/oneflow/user/ops/zero_like_op.cpp +++ b/oneflow/user/ops/zero_like_op.cpp @@ -33,7 +33,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe ZeroLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); return Maybe::Ok(); } /*static*/ Maybe ZeroLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { From eda0dc3d8e5c56a25d678c47c846821be5a576c1 Mon Sep 17 00:00:00 2001 From: clackhan Date: Wed, 20 Jul 2022 17:36:14 +0800 Subject: [PATCH 39/67] define_mut_output_shape_and_mut_output_stride_in_infer_ctx --- oneflow/core/framework/infer_util.cpp | 2 +- oneflow/core/framework/infer_util.h | 12 ++-- oneflow/core/framework/op_expr.cpp | 34 +++++++++-- oneflow/core/framework/op_kernel.cpp | 2 +- oneflow/core/kernel/user_kernel.cpp | 28 +++++++-- oneflow/core/operator/user_op.cpp | 38 +++++++++--- oneflow/ir/oneflow-extension/extension.cpp | 2 +- ...ttention_query_mul_key_and_value_kernel.cu | 4 +- ...random_batch_permutation_indices_kernel.cu | 4 +- 
.../kernels/nccl_logical_send_recv_kernel.cpp | 4 +- oneflow/user/kernels/nms_kernel.cu | 4 +- oneflow/user/kernels/stateful_opkernel.cpp | 58 ++++++++++++++----- .../user/kernels/two_stage_reduce_kernel.cpp | 8 +-- .../kernels/unsorted_segment_sum_kernel.cpp | 4 +- oneflow/user/kernels/where_kernel.cpp | 20 +++---- oneflow/user/ops/acc_op.cpp | 2 +- oneflow/user/ops/adaptive_pool_op.cpp | 4 +- oneflow/user/ops/arange_op.cpp | 4 +- oneflow/user/ops/arg_sort_op.cpp | 2 +- oneflow/user/ops/argmax_op.cpp | 2 +- oneflow/user/ops/avg_pool_op.cpp | 4 +- oneflow/user/ops/bias_add_op.cpp | 2 +- oneflow/user/ops/broadcast_div_grad_op.cpp | 2 +- oneflow/user/ops/broadcast_like_op.cpp | 4 +- oneflow/user/ops/broadcast_pow_grad_op.cpp | 4 +- oneflow/user/ops/buffer_op.cpp | 2 +- oneflow/user/ops/cast_like_op.cpp | 2 +- oneflow/user/ops/cast_to_tick_op.cpp | 2 +- .../ops/categorical_ordinal_encode_op.cpp | 4 +- oneflow/user/ops/celu_op.cpp | 4 +- oneflow/user/ops/clip_by_value_op.cpp | 4 +- oneflow/user/ops/combined_margin_loss_op.cpp | 4 +- oneflow/user/ops/constant_op.cpp | 4 +- oneflow/user/ops/conv_op.cpp | 2 +- oneflow/user/ops/copy_op.cpp | 4 +- oneflow/user/ops/ctc_loss_op.cpp | 10 ++-- .../cublas_bias_add_relu_matmul_grad_op.cpp | 4 +- .../cublas_fused_matmul_bias_add_grad_op.cpp | 4 +- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 6 +- oneflow/user/ops/cublas_fused_mlp_op.cpp | 6 +- oneflow/user/ops/cum_ops.cpp | 6 +- oneflow/user/ops/data_shuffle_op.cpp | 24 ++++---- oneflow/user/ops/distributions/normal_op.cpp | 4 +- .../user/ops/distributions/uniform_int_op.cpp | 4 +- oneflow/user/ops/distributions/uniform_op.cpp | 4 +- oneflow/user/ops/dot_op.cpp | 2 +- oneflow/user/ops/dropout_op.cpp | 8 +-- oneflow/user/ops/eager_b_to_s_op.cpp | 2 +- oneflow/user/ops/eager_nccl_ops.cpp | 14 ++--- oneflow/user/ops/eager_p_to_b_op.cpp | 2 +- oneflow/user/ops/eager_p_to_s_op.cpp | 2 +- oneflow/user/ops/eager_s_to_b_op.cpp | 2 +- oneflow/user/ops/eager_s_to_p_op.cpp | 2 +- 
oneflow/user/ops/eager_s_to_s_op.cpp | 2 +- .../user/ops/eager_symmetric_s_to_p_op.cpp | 2 +- oneflow/user/ops/elu_op.cpp | 4 +- oneflow/user/ops/embedding_op.cpp | 2 +- oneflow/user/ops/empty_op.cpp | 8 +-- oneflow/user/ops/erfinv_op.cpp | 2 +- oneflow/user/ops/expand_dims_op.cpp | 2 +- oneflow/user/ops/expand_op.cpp | 4 +- oneflow/user/ops/eye_op.cpp | 2 +- oneflow/user/ops/fake_quantization_op.cpp | 2 +- oneflow/user/ops/fill_op.cpp | 8 +-- oneflow/user/ops/fused_bias_add_op.cpp | 6 +- .../fused_cross_feature_interaction_op.cpp | 20 +++---- .../ops/fused_dot_feature_interaction_op.cpp | 10 ++-- oneflow/user/ops/fused_gru_cell_op.cpp | 14 ++--- oneflow/user/ops/fused_lstm_cell_op.cpp | 14 +++-- .../fused_matmul_bias_add_relu_dropout_op.cpp | 6 +- .../user/ops/fused_relu_dropout_grad_op.cpp | 2 +- .../fused_scale_mask_softmax_dropout_op.cpp | 4 +- .../user/ops/fused_scale_mask_softmax_op.cpp | 2 +- ...fused_scale_tril_softmax_mask_scale_op.cpp | 4 +- ..._attention_query_mul_key_and_value_ops.cpp | 6 +- oneflow/user/ops/gelu_op.cpp | 4 +- ...te_random_batch_permutation_indices_op.cpp | 2 +- oneflow/user/ops/hardshrink_op.cpp | 4 +- oneflow/user/ops/hardsigmoid_op.cpp | 4 +- oneflow/user/ops/hardswish_op.cpp | 4 +- oneflow/user/ops/hardtanh_op.cpp | 4 +- .../ops/hierarchical_parallel_cast_op.cpp | 4 +- oneflow/user/ops/identity_op.cpp | 2 +- .../user/ops/image_object_preprocess_ops.cpp | 14 ++--- oneflow/user/ops/image_preprocess_ops.cpp | 2 +- .../user/ops/l1_l2_regularize_gradient_op.cpp | 2 +- oneflow/user/ops/l2_normalize_op.cpp | 6 +- oneflow/user/ops/leaky_relu_op.cpp | 4 +- oneflow/user/ops/log_softmax_op.cpp | 4 +- oneflow/user/ops/masked_fill_op.cpp | 2 +- .../user/ops/math_binary_broadcast_ops.cpp | 10 ++-- oneflow/user/ops/matmul_op.cpp | 2 +- oneflow/user/ops/matrix_vector_product_op.cpp | 6 +- oneflow/user/ops/median_op.cpp | 2 +- oneflow/user/ops/median_with_indices_op.cpp | 4 +- oneflow/user/ops/min_max_observer_op.cpp | 12 ++-- 
oneflow/user/ops/mish_op.cpp | 4 +- oneflow/user/ops/model_update_ops.cpp | 2 +- .../moving_average_min_max_observer_op.cpp | 4 +- oneflow/user/ops/multi_reduce_ops.cpp | 6 +- oneflow/user/ops/narrow_op.cpp | 2 +- oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp | 10 ++-- oneflow/user/ops/nccl_logical_ops.cpp | 14 ++--- oneflow/user/ops/nd_index_slice_ops.cpp | 8 +-- oneflow/user/ops/nms_op.cpp | 2 +- oneflow/user/ops/nvtx_range_op.cpp | 4 +- oneflow/user/ops/one_embedding_ops.cpp | 22 +++---- oneflow/user/ops/ones_like_op.cpp | 4 +- oneflow/user/ops/p2p_comm_op.cpp | 2 +- oneflow/user/ops/pad_op.cpp | 2 +- oneflow/user/ops/padding_ops.cpp | 8 +-- oneflow/user/ops/parallel_cast_op.cpp | 2 +- oneflow/user/ops/partial_fc_sample_op.cpp | 4 +- oneflow/user/ops/prelu_op.cpp | 6 +- oneflow/user/ops/quantization_op.cpp | 2 +- oneflow/user/ops/randperm_op.cpp | 4 +- oneflow/user/ops/reduce_ops.cpp | 4 +- oneflow/user/ops/relu_op.cpp | 4 +- oneflow/user/ops/repeat_op.cpp | 2 +- oneflow/user/ops/reshape_like_op.cpp | 2 +- oneflow/user/ops/roi_align_op.cpp | 4 +- oneflow/user/ops/roll_op.cpp | 2 +- oneflow/user/ops/same_padding_op.cpp | 2 +- oneflow/user/ops/scalar_logical_op.cpp | 2 +- oneflow/user/ops/scalar_math_op.cpp | 6 +- oneflow/user/ops/search_sorted_op.cpp | 4 +- oneflow/user/ops/selu_op.cpp | 4 +- oneflow/user/ops/silu_op.cpp | 4 +- oneflow/user/ops/slice_op.cpp | 6 +- oneflow/user/ops/softmax_cross_entropy_op.cpp | 4 +- oneflow/user/ops/softmax_op.cpp | 4 +- oneflow/user/ops/softplus_op.cpp | 4 +- oneflow/user/ops/softshrink_op.cpp | 4 +- oneflow/user/ops/softsign_op.cpp | 4 +- oneflow/user/ops/sort_op.cpp | 2 +- oneflow/user/ops/sparse_cross_entropy_op.cpp | 2 +- .../ops/sparse_softmax_cross_entropy_op.cpp | 4 +- oneflow/user/ops/squeeze_op.cpp | 2 +- oneflow/user/ops/ssp_variable_proxy_op.cpp | 4 +- oneflow/user/ops/tf_pool_op.cpp | 2 +- oneflow/user/ops/tf_prelu_op.cpp | 2 +- oneflow/user/ops/threshold_op.cpp | 4 +- oneflow/user/ops/to_contiguous_op.cpp | 4 +- 
oneflow/user/ops/top_k_op.cpp | 2 +- oneflow/user/ops/tuple_identity_op.cpp | 2 +- oneflow/user/ops/two_stage_reduce_ops.cpp | 20 +++---- oneflow/user/ops/unfold_fold_op.cpp | 4 +- oneflow/user/ops/unfold_tensor_op.cpp | 2 +- oneflow/user/ops/unsorted_segment_sum_op.cpp | 4 +- oneflow/user/ops/upsample_op.cpp | 14 ++--- oneflow/user/ops/util_ops.cpp | 4 +- oneflow/user/ops/variance_op.cpp | 2 +- oneflow/user/ops/vector_matrix_product_op.cpp | 6 +- oneflow/user/ops/where_op.cpp | 14 ++--- oneflow/user/ops/zero_like_op.cpp | 2 +- 155 files changed, 497 insertions(+), 399 deletions(-) diff --git a/oneflow/core/framework/infer_util.cpp b/oneflow/core/framework/infer_util.cpp index 599f6a9070d..4ccd9ca7955 100644 --- a/oneflow/core/framework/infer_util.cpp +++ b/oneflow/core/framework/infer_util.cpp @@ -40,7 +40,7 @@ Maybe TensorDescInferFnUtil::Unchanged(InferContext* ctx) { for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); *ctx->OutputIsDynamic(output_arg.first, output_arg.second) = first_tensor_desc->is_dynamic(); - *ctx->OutputShape(output_arg.first, output_arg.second) = first_tensor_desc->shape(); + *ctx->MutOutputShape(output_arg.first, output_arg.second) = first_tensor_desc->shape(); } return Maybe::Ok(); } diff --git a/oneflow/core/framework/infer_util.h b/oneflow/core/framework/infer_util.h index 5b32ea31844..15b77cde0af 100644 --- a/oneflow/core/framework/infer_util.h +++ b/oneflow/core/framework/infer_util.h @@ -43,11 +43,15 @@ class InferContext { virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const Shape& InputShape(const std::string&, int32_t) const = 0; - virtual Shape* OutputShape(const std::string&, int32_t) = 0; - virtual Shape* Shape4ArgNameAndIndex(const std::string&, int32_t) = 0; + virtual const Shape& OutputShape(const std::string&, int32_t) const = 0; + virtual Shape* MutOutputShape(const std::string&, int32_t) = 0; + virtual const 
Shape& Shape4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual Shape* MutShape4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const Stride& InputStride(const std::string&, int32_t) const = 0; - virtual Stride* OutputStride(const std::string&, int32_t) = 0; - virtual Stride* Stride4ArgNameAndIndex(const std::string&, int32_t) = 0; + virtual const Stride& OutputStride(const std::string&, int32_t) const = 0; + virtual Stride* MutOutputStride(const std::string&, int32_t) = 0; + virtual const Stride& Stride4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual Stride* MutStride4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const DataType& InputDType(const std::string&, int32_t) const = 0; virtual DataType* OutputDType(const std::string&, int32_t) = 0; virtual DataType* Dtype4ArgNameAndIndex(const std::string&, int32_t) = 0; diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 13113237061..9e07d3f0ccc 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -221,14 +221,27 @@ class UserOpExprInferContext : public user_op::InferContext { return tensor_meta4input_index_(tuple_index)->shape(); } - Shape* OutputShape(const std::string& name, int32_t index) override { + const Shape& OutputShape(const std::string& name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4output_index_(tuple_index)->shape(); + } + + Shape* MutOutputShape(const std::string& name, int32_t index) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4output_index_(tuple_index)->mut_shape(); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index)
override { + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->shape(); + } + + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); } @@ -239,14 +252,27 @@ class UserOpExprInferContext : public user_op::InferContext { return tensor_meta4input_index_(tuple_index)->stride(); } - Stride* OutputStride(const std::string& name, int32_t index) override { + const Stride& OutputStride(const std::string& name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4output_index_(tuple_index)->stride(); + } + + Stride* MutOutputStride(const std::string& name, int32_t index) override { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4output_index_(tuple_index)->mut_stride(); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->stride(); + } + + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } diff --git a/oneflow/core/framework/op_kernel.cpp b/oneflow/core/framework/op_kernel.cpp index 73add18775f..cbbfc59f2d7 100644 --- a/oneflow/core/framework/op_kernel.cpp +++ b/oneflow/core/framework/op_kernel.cpp @@ -25,7 +25,7 @@ void OpKernel::InferShape(KernelInferContext* ctx) const { CHECK_NOTNULL(op_infer_ctx); ctx->GetOpInferFn()(op_infer_ctx); for
(const auto& arg_pair : ctx->outputs()) { - const Shape& shape = *op_infer_ctx->OutputShape(arg_pair.first, arg_pair.second); + const Shape& shape = op_infer_ctx->OutputShape(arg_pair.first, arg_pair.second); auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex(arg_pair.first, arg_pair.second); mut_shape_view.set_shape(shape); } diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 12c40c20d2a..0dd9a3c26d2 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -261,21 +261,37 @@ class UserKernelOpInferContext : public user_op::InferContext { return it->second.get(); } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Shape4ArgNameAndIndex(arg_name, index); + return Shape4ArgNameAndIndex(arg_name, index); } - Shape* OutputShape(const std::string& arg_name, int32_t index) override { + const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Shape* MutOutputShape(const std::string& arg_name, int32_t index) override { + return MutShape4ArgNameAndIndex(arg_name, index); + } + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->shape(); + } + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Stride4ArgNameAndIndex(arg_name, index); + return Stride4ArgNameAndIndex(arg_name, index); } - Stride* OutputStride(const std::string& arg_name, int32_t index) override { + const Stride& OutputStride(const std::string& arg_name, int32_t index) 
const override { return Stride4ArgNameAndIndex(arg_name, index); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Stride* MutOutputStride(const std::string& arg_name, int32_t index) override { + return MutStride4ArgNameAndIndex(arg_name, index); + } + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->stride(); + } + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } const DataType& InputDType(const std::string& arg_name, int32_t index) const override { diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index e7e9d8c2d2f..01e07032b45 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -171,23 +171,45 @@ class UserOpInferContext final : public user_op::InferContext { } } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Shape4ArgNameAndIndex(arg_name, index); + return Shape4ArgNameAndIndex(arg_name, index); } - Shape* OutputShape(const std::string& arg_name, int32_t index) override { + const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return Shape4ArgNameAndIndex(arg_name, index); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Shape* MutOutputShape(const std::string& arg_name, int32_t index) override { + return MutShape4ArgNameAndIndex(arg_name, index); + } + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { + thread_local static Shape non_shape; + return non_shape; + }; + return it->second.shape(); + } + Shape* 
MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return it->second.mut_shape(); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Stride4ArgNameAndIndex(arg_name, index); + return Stride4ArgNameAndIndex(arg_name, index); } - Stride* OutputStride(const std::string& arg_name, int32_t index) override { + const Stride& OutputStride(const std::string& arg_name, int32_t index) const override { return Stride4ArgNameAndIndex(arg_name, index); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Stride* MutOutputStride(const std::string& arg_name, int32_t index) override { + return MutStride4ArgNameAndIndex(arg_name, index); + } + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { + thread_local static Stride non_stride; + return non_stride; + }; + return it->second.stride(); + } + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return it->second.mut_stride(); @@ -612,8 +634,8 @@ Maybe UserOp::InferOutBlobDescs( for (const auto& pair : infer_ctx.outputs()) { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second)); out_blob_desc->set_data_type(*(infer_ctx.OutputDType(pair.first, pair.second))); - out_blob_desc->mut_shape() = *(infer_ctx.OutputShape(pair.first, pair.second)); - out_blob_desc->mut_stride() = Stride(*(infer_ctx.OutputShape(pair.first, pair.second))); + out_blob_desc->mut_shape() = infer_ctx.OutputShape(pair.first, pair.second); + out_blob_desc->mut_stride() = 
Stride(infer_ctx.OutputShape(pair.first, pair.second)); out_blob_desc->set_is_dynamic(*infer_ctx.OutputIsDynamic(pair.first, pair.second)); } return Maybe::Ok(); diff --git a/oneflow/ir/oneflow-extension/extension.cpp b/oneflow/ir/oneflow-extension/extension.cpp index 9954ed6dd8d..78d574b4376 100644 --- a/oneflow/ir/oneflow-extension/extension.cpp +++ b/oneflow/ir/oneflow-extension/extension.cpp @@ -49,7 +49,7 @@ REGISTER_USER_OP("mlir_jit") CHECK_EQ(ctx->inputs().size(), 2); CHECK_EQ(ctx->outputs().size(), 1); const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; *ctx->OutputDType("out", 0) = ctx->InputDType("in", 1); return Maybe::Ok(); diff --git a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu index 0243ac36ec7..ea49e053512 100644 --- a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu +++ b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu @@ -266,9 +266,9 @@ class FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel final : public user_op: }; size_t InferTmpBufferSize(user_op::InferContext* ctx) { - const Shape* value_shape = ctx->OutputShape("value", 0); + const Shape& value_shape = ctx->OutputShape("value", 0); DataType value_dtype = *ctx->OutputDType("value", 0); - return value_shape->elem_cnt() * GetSizeOfDataType(value_dtype); + return value_shape.elem_cnt() * GetSizeOfDataType(value_dtype); } size_t InferGradTmpBufferSize(user_op::InferContext* ctx) { diff --git a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu index 97ec84abf6d..8928fc5bd9e 100644 --- a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu +++ 
b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu @@ -119,8 +119,8 @@ REGISTER_USER_KERNEL("generate_random_batch_permutation_indices") .SetCreateFn() .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA) .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { - const Shape* y_shape = ctx->OutputShape("y", 0); - const int32_t batch_size = y_shape->At(0); + const Shape& y_shape = ctx->OutputShape("y", 0); + const int32_t batch_size = y_shape.At(0); const int32_t random_value_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(float)); const int32_t sorted_value_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(float)); diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 6148e952101..714c9a5cbd3 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -252,7 +252,7 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O } size_t InferTmpBufferSize(user_op::InferContext* ctx) { - const Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& out_shape = ctx->OutputShape("out", 0); const user_op::TensorDesc* logical_in_tensor = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); const Shape& logical_shape = logical_in_tensor->shape(); const DataType data_type = logical_in_tensor->data_type(); @@ -278,7 +278,7 @@ size_t InferTmpBufferSize(user_op::InferContext* ctx) { } if (NdSbpHasPartialParallel(src_nd_sbp)) { // Note: when src_nd_sbp has partial_sum, need a out_size buffer to copy and add to out. 
- buf_count += out_shape->elem_cnt(); + buf_count += out_shape.elem_cnt(); } return buf_count * GetSizeOfDataType(data_type); } diff --git a/oneflow/user/kernels/nms_kernel.cu b/oneflow/user/kernels/nms_kernel.cu index 8a1f1785e0e..fa3984af8ab 100644 --- a/oneflow/user/kernels/nms_kernel.cu +++ b/oneflow/user/kernels/nms_kernel.cu @@ -132,8 +132,8 @@ class NmsGpuKernel final : public user_op::OpKernel { && (user_op::HobDataType("out", 0) == DataType::kInt8) \ && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* in_shape = ctx->Shape4ArgNameAndIndex("in", 0); \ - int64_t num_boxes = in_shape->At(0); \ + const Shape& in_shape = ctx->Shape4ArgNameAndIndex("in", 0); \ + int64_t num_boxes = in_shape.At(0); \ int64_t blocks = CeilDiv(num_boxes, kBlockSize); \ return num_boxes * blocks * sizeof(int64_t); \ }); diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 0808219276f..71950fb65f3 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -174,26 +174,42 @@ class UserOpInferContextHelper final { const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *Shape4ArgNameAndIndex(call_ctx, arg_name, index); + return Shape4ArgNameAndIndex(call_ctx, arg_name, index); } - Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + const Shape& OutputShape(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return Shape4ArgNameAndIndex(call_ctx, arg_name, index); } - Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + Shape* MutOutputShape(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return MutShape4ArgNameAndIndex(call_ctx, arg_name, index); + } + const Shape& 
Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->shape(); + } + Shape* MutShape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape(); } const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *Stride4ArgNameAndIndex(call_ctx, arg_name, index); + return Stride4ArgNameAndIndex(call_ctx, arg_name, index); } - Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + const Stride& OutputStride(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return Stride4ArgNameAndIndex(call_ctx, arg_name, index); } - Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + Stride* MutOutputStride(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return MutStride4ArgNameAndIndex(call_ctx, arg_name, index); + } + const Stride& Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->stride(); + } + Stride* MutStride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride(); } const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name, @@ -317,21 +333,33 @@ class UserOpInferContext : public user_op::InferContext { const Shape& InputShape(const std::string& arg_name, int32_t index) const override { return helper_->InputShape(call_ctx_, arg_name, index); } - Shape* OutputShape(const std::string& arg_name, int32_t index) override { + 
const Shape& OutputShape(const std::string& arg_name, int32_t index) const override { return helper_->OutputShape(call_ctx_, arg_name, index); } - Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + Shape* MutOutputShape(const std::string& arg_name, int32_t index) override { + return helper_->MutOutputShape(call_ctx_, arg_name, index); + } + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index); } + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return helper_->MutShape4ArgNameAndIndex(call_ctx_, arg_name, index); + } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { return helper_->InputStride(call_ctx_, arg_name, index); } - Stride* OutputStride(const std::string& arg_name, int32_t index) override { - return helper_->OutputStride(call_ctx_, arg_name, index); + const Stride& OutputStride(const std::string& arg_name, int32_t index) const override { + return helper_->OutputStride(call_ctx_, arg_name, index); + } + Stride* MutOutputStride(const std::string& arg_name, int32_t index) override { + return helper_->MutOutputStride(call_ctx_, arg_name, index); } - Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index); } + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return helper_->MutStride4ArgNameAndIndex(call_ctx_, arg_name, index); + } const DataType& InputDType(const std::string& arg_name, int32_t index) const override { return helper_->InputDType(call_ctx_, arg_name, index); } diff --git a/oneflow/user/kernels/two_stage_reduce_kernel.cpp b/oneflow/user/kernels/two_stage_reduce_kernel.cpp index
c76eaa9749d..429b0bd0ddf 100644 --- a/oneflow/user/kernels/two_stage_reduce_kernel.cpp +++ b/oneflow/user/kernels/two_stage_reduce_kernel.cpp @@ -127,9 +127,9 @@ template user_op::InferTmpSizeFn GenDeviceStageGradInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const Shape& out_diff_shape = ctx->InputShape("out_diff", 0); - const Shape* in_diff_shape = ctx->OutputShape("in_diff", 0); + const Shape& in_diff_shape = ctx->OutputShape("in_diff", 0); const size_t tmp_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T)); - const size_t broadcasted_tmp_bytes = GetCudaAlignedSize(in_diff_shape->elem_cnt() * sizeof(T)); + const size_t broadcasted_tmp_bytes = GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T)); return tmp_bytes + broadcasted_tmp_bytes; }; } @@ -259,7 +259,7 @@ user_op::InferTmpSizeFn GenGlobalStageGradInferTmpSizeFn() { return [](user_op::InferContext* ctx) { const Shape& device_count_shape = ctx->InputShape("device_count", 0); const Shape& out_diff_shape = ctx->InputShape("out_diff", 0); - const Shape* in_diff_shape = ctx->OutputShape("in_diff", 0); + const Shape& in_diff_shape = ctx->OutputShape("in_diff", 0); const size_t device_count_with_mask_bytes = GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t)); const size_t global_count_bytes = @@ -268,7 +268,7 @@ user_op::InferTmpSizeFn GenGlobalStageGradInferTmpSizeFn() { GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t)); const size_t divided_buf_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T)); const size_t broadcasted_divided_buf_bytes = - GetCudaAlignedSize(in_diff_shape->elem_cnt() * sizeof(T)); + GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T)); const size_t total_bytes = device_count_with_mask_bytes + global_count_bytes + reduce_sum_tmp_bytes + divided_buf_bytes + broadcasted_divided_buf_bytes; diff --git a/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp b/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp index 
bcd7b1c5364..f18bd44f99a 100644 --- a/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp +++ b/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp @@ -193,8 +193,8 @@ class UnsortedSegmentSumHalfKernel final : public user_op::OpKernel { && (user_op::HobDataType("segment_ids", 0) == OF_PP_PAIR_SECOND(segment_ids_type)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(out_type))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - const Shape* out_shape = ctx->OutputShape("out", 0); \ - return GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(float)); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ + return GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float)); \ }); #define REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE(out_type, segment_ids_type) \ diff --git a/oneflow/user/kernels/where_kernel.cpp b/oneflow/user/kernels/where_kernel.cpp index ee9265f6cf5..0797dd151f7 100644 --- a/oneflow/user/kernels/where_kernel.cpp +++ b/oneflow/user/kernels/where_kernel.cpp @@ -191,13 +191,13 @@ class WhereScalarXYKernel final : public user_op::OpKernel { && (user_op::HobDataType("condition", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* out_shape = ctx->OutputShape("out", 0); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ const size_t x_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t y_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t cond_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ return 
x_bytes + y_bytes + cond_bytes; \ }); @@ -209,11 +209,11 @@ class WhereScalarXYKernel final : public user_op::OpKernel { && (user_op::HobDataType("condition", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* out_shape = ctx->OutputShape("out", 0); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ const size_t y_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t cond_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ return y_bytes + cond_bytes; \ }); @@ -225,11 +225,11 @@ class WhereScalarXYKernel final : public user_op::OpKernel { && (user_op::HobDataType("condition", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \ && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - Shape* out_shape = ctx->OutputShape("out", 0); \ + const Shape& out_shape = ctx->OutputShape("out", 0); \ const size_t x_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair))); \ const size_t cond_bytes = \ - GetCudaAlignedSize(out_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(ctype_pair))); \ return x_bytes + cond_bytes; \ }); diff --git a/oneflow/user/ops/acc_op.cpp b/oneflow/user/ops/acc_op.cpp index 92df9df8f8e..f645c023711 100644 --- a/oneflow/user/ops/acc_op.cpp +++ b/oneflow/user/ops/acc_op.cpp @@ -30,7 +30,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe AccOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { 
- *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/adaptive_pool_op.cpp b/oneflow/user/ops/adaptive_pool_op.cpp index 935e644ea83..35cf44f0c1d 100644 --- a/oneflow/user/ops/adaptive_pool_op.cpp +++ b/oneflow/user/ops/adaptive_pool_op.cpp @@ -31,12 +31,12 @@ Maybe InferFWTensorDesc(user_op::InferContext* ctx) { out_shape[i] = output_size.size() > i - 2 ? output_size[i - 2] : output_size[0]; } - *ctx->OutputShape("y", 0) = Shape(out_shape); + *ctx->MutOutputShape("y", 0) = Shape(out_shape); return Maybe::Ok(); } Maybe InferBWTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/arange_op.cpp b/oneflow/user/ops/arange_op.cpp index 73585347376..36a3c954c11 100644 --- a/oneflow/user/ops/arange_op.cpp +++ b/oneflow/user/ops/arange_op.cpp @@ -21,7 +21,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe ArangeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); DataType dtype = ctx->Attr("dtype"); int64_t range_elem_cnt = 0; if (IsIntegralDataType(dtype)) { @@ -88,7 +88,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/arg_sort_op.cpp b/oneflow/user/ops/arg_sort_op.cpp index e4ca90915ff..55cf61d6f05 100644 --- a/oneflow/user/ops/arg_sort_op.cpp +++ b/oneflow/user/ops/arg_sort_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe ArgSortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/argmax_op.cpp b/oneflow/user/ops/argmax_op.cpp index 58c6581eb29..17cb35709bf 100644 --- a/oneflow/user/ops/argmax_op.cpp +++ b/oneflow/user/ops/argmax_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe ArgmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { auto dim_vec = ctx->InputShape("in", 0).dim_vec(); dim_vec.pop_back(); - *ctx->OutputShape("out", 0) = Shape(std::move(dim_vec)); + *ctx->MutOutputShape("out", 0) = Shape(std::move(dim_vec)); return Maybe::Ok(); } diff --git a/oneflow/user/ops/avg_pool_op.cpp b/oneflow/user/ops/avg_pool_op.cpp index e6d1521707d..23b4f8377ad 100644 --- a/oneflow/user/ops/avg_pool_op.cpp +++ b/oneflow/user/ops/avg_pool_op.cpp @@ -27,7 +27,7 @@ typedef std::function(const user_op::UserOpWrapper& op, user_op::Add TensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) { return 
[dim](user_op::InferContext* ctx) -> Maybe { - const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0); + const Shape& x_shape = ctx->Shape4ArgNameAndIndex("x", 0); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); const std::vector& kernel_size = ctx->Attr>("kernel_size"); @@ -53,7 +53,7 @@ TensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) { << "pad should be smaller than half of kernel size"; } - const AvgPoolParams3D params_3d(dim, *x_shape, data_format, padding, kernel_size, stride, + const AvgPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride, ceil_mode, count_include_pad, divisor_override); user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); *y_desc = ctx->InputTensorDesc("x", 0); diff --git a/oneflow/user/ops/bias_add_op.cpp b/oneflow/user/ops/bias_add_op.cpp index 77dfff37837..963ac103951 100644 --- a/oneflow/user/ops/bias_add_op.cpp +++ b/oneflow/user/ops/bias_add_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { << Error::RuntimeError() << "The size of tensor " << a_tensor_desc.shape().ToString() << " must match the size of tensor " << b_tensor_desc.shape().ToString() << " at dimension " << bias_add_axis; - *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("a", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/broadcast_div_grad_op.cpp b/oneflow/user/ops/broadcast_div_grad_op.cpp index c59b2436997..791fa84ad1b 100644 --- a/oneflow/user/ops/broadcast_div_grad_op.cpp +++ b/oneflow/user/ops/broadcast_div_grad_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe BroadcastDivGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("dy", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index 1478378ea7f..1e6f1456cac 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -78,8 +78,8 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { CHECK_OR_RETURN(!broadcast_axes.empty()); const Shape& in_shape = ctx->InputShape("x", 0); const Shape& like_shape = ctx->InputShape("like", 0); - Shape* out_shape = ctx->OutputShape("y", 0); - Stride* out_stride = ctx->OutputStride("y", 0); + Shape* out_shape = ctx->MutOutputShape("y", 0); + Stride* out_stride = ctx->MutOutputStride("y", 0); const AxisVector axis_vec = {broadcast_axes.begin(), broadcast_axes.end()}; CHECK_OR_RETURN(IsAxesLegal(axis_vec, like_shape, in_shape)); *out_shape = like_shape; diff --git a/oneflow/user/ops/broadcast_pow_grad_op.cpp b/oneflow/user/ops/broadcast_pow_grad_op.cpp index 21fa575b03b..ab23165638a 100644 --- a/oneflow/user/ops/broadcast_pow_grad_op.cpp +++ b/oneflow/user/ops/broadcast_pow_grad_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe BroadcastPowXGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } @@ -76,7 +76,7 @@ namespace oneflow { } /* static */ Maybe BroadcastPowYGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("dy", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/buffer_op.cpp b/oneflow/user/ops/buffer_op.cpp index eb8abde1ee6..86f8cd1e79e 100644 --- a/oneflow/user/ops/buffer_op.cpp +++ b/oneflow/user/ops/buffer_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe IdentityBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index c4d41a00be8..77cc334b087 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe CastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_to_tick_op.cpp b/oneflow/user/ops/cast_to_tick_op.cpp index bb76f5887e6..576ca9fc220 100644 --- a/oneflow/user/ops/cast_to_tick_op.cpp +++ b/oneflow/user/ops/cast_to_tick_op.cpp @@ -20,7 +20,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe CastToTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = Shape({1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index ca2b4533826..e478d910532 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { const Shape& size_shape = ctx->InputShape("size", 0); CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -39,7 +39,7 @@ namespace oneflow { const Shape& size_shape = ctx->InputShape("size", 0); CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/celu_op.cpp b/oneflow/user/ops/celu_op.cpp index 60d48152434..039124a0f6d 100644 --- a/oneflow/user/ops/celu_op.cpp +++ b/oneflow/user/ops/celu_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe CeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe CeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/clip_by_value_op.cpp b/oneflow/user/ops/clip_by_value_op.cpp index f216e077816..63363bbb153 100644 --- a/oneflow/user/ops/clip_by_value_op.cpp +++ b/oneflow/user/ops/clip_by_value_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { namespace { Maybe InferClipTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } @@ -34,7 +34,7 @@ Maybe GetClipSbpSignature(user_op::SbpContext* ctx) { } Maybe InferClipGradTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index 72854a53928..65b462ac1b0 100644 --- a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -24,7 +24,7 @@ namespace oneflow { user_op::TensorDesc* theta = ctx->OutputTensorDesc("theta", 0); CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); *ctx->IsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); 
*theta->mut_is_dynamic() = x.is_dynamic(); *theta->mut_shape() = label.shape(); @@ -72,7 +72,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0)); CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("dy", 0); *ctx->IsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/constant_op.cpp b/oneflow/user/ops/constant_op.cpp index 62d9bdcc050..4a14f638b43 100644 --- a/oneflow/user/ops/constant_op.cpp +++ b/oneflow/user/ops/constant_op.cpp @@ -20,7 +20,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe ConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } @@ -33,7 +33,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/conv_op.cpp b/oneflow/user/ops/conv_op.cpp index 64940f4d2da..ce753a087f3 100644 --- a/oneflow/user/ops/conv_op.cpp +++ b/oneflow/user/ops/conv_op.cpp @@ -308,7 +308,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); } - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x_like", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/copy_op.cpp 
b/oneflow/user/ops/copy_op.cpp index 6b7d5f994f2..f283e7c716a 100644 --- a/oneflow/user/ops/copy_op.cpp +++ b/oneflow/user/ops/copy_op.cpp @@ -42,8 +42,8 @@ Maybe> MakeCopyStream(const Symbol& in_device, } // namespace /* static */ Maybe CopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputStride("out", 0) = ctx->InputStride("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputStride("out", 0) = ctx->InputStride("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp index b8dee1ad9cc..3b8e466c923 100644 --- a/oneflow/user/ops/ctc_loss_op.cpp +++ b/oneflow/user/ops/ctc_loss_op.cpp @@ -34,8 +34,8 @@ namespace oneflow { CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); - *ctx->OutputShape("loss", 0) = Shape({batch_size}); - *ctx->OutputShape("alpha", 0) = + *ctx->MutOutputShape("loss", 0) = Shape({batch_size}); + *ctx->MutOutputShape("alpha", 0) = Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1}); return Maybe::Ok(); } @@ -78,7 +78,7 @@ namespace oneflow { CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); - *ctx->OutputShape("grad", 0) = log_probs.shape(); + *ctx->MutOutputShape("grad", 0) = log_probs.shape(); return Maybe::Ok(); } @@ -110,8 +110,8 @@ namespace oneflow { const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); const int64_t batch_size = log_probs.shape().At(1); CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0)); - *ctx->OutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)}); - *ctx->OutputShape("neg_sum_logits", 0) = Shape({batch_size, 1}); + *ctx->MutOutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)}); + 
*ctx->MutOutputShape("neg_sum_logits", 0) = Shape({batch_size, 1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp b/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp index ae09393bf85..0114b96336a 100644 --- a/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp +++ b/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp @@ -28,8 +28,8 @@ Maybe InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) { const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); const int64_t bias_size = weight_desc.shape().At(1); Shape d_grad_shape({dy_desc.shape().At(0), weight_desc.shape().At(1)}); - *ctx->OutputShape("d_grad", 0) = d_grad_shape; - *ctx->OutputShape("d_bias", 0) = Shape({bias_size}); + *ctx->MutOutputShape("d_grad", 0) = d_grad_shape; + *ctx->MutOutputShape("d_bias", 0) = Shape({bias_size}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp b/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp index 8ae2e512d62..58e9b5e6912 100644 --- a/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp @@ -36,8 +36,8 @@ Maybe InferTensorDesc4MatmulBiasAddBackward(user_op::InferContext* ctx) { const int64_t bias_size = dy_desc.shape().At(1); Shape w_grad_shape({dy_desc.shape().At(1), x_desc.shape().At(1)}); - *ctx->OutputShape("w_grad", 0) = w_grad_shape; - *ctx->OutputShape("b_grad", 0) = Shape({bias_size}); + *ctx->MutOutputShape("w_grad", 0) = w_grad_shape; + *ctx->MutOutputShape("b_grad", 0) = Shape({bias_size}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp index cf4fd9d3bcd..f21853568a1 100644 --- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -25,10 +25,10 @@ Maybe InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) { const 
user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); for (int idx = weight_num - 1; idx >= 0; idx--) { const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", idx); - *ctx->OutputShape("d_weights", idx) = weight_desc.shape(); - *ctx->OutputShape("d_biases", idx) = Shape({weight_desc.shape().At(0)}); + *ctx->MutOutputShape("d_weights", idx) = weight_desc.shape(); + *ctx->MutOutputShape("d_biases", idx) = Shape({weight_desc.shape().At(0)}); } - *ctx->OutputShape("d_x", 0) = x_desc.shape(); + *ctx->MutOutputShape("d_x", 0) = x_desc.shape(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_fused_mlp_op.cpp b/oneflow/user/ops/cublas_fused_mlp_op.cpp index 9bc5d9f1b57..9369a0e303c 100644 --- a/oneflow/user/ops/cublas_fused_mlp_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_op.cpp @@ -65,12 +65,12 @@ Maybe InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) { // Set Middle result shape. long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld); int64_t aux_size = cublas_aligned_aux_ld / 32; // Cause we use int32_t as dtype - *ctx->OutputShape("cublas_aux", idx) = Shape({m, aux_size}); - *ctx->OutputShape("hidden", idx) = Shape({m, n}); + *ctx->MutOutputShape("cublas_aux", idx) = Shape({m, aux_size}); + *ctx->MutOutputShape("hidden", idx) = Shape({m, n}); // Set for next layer. k = n; } - *ctx->OutputShape("out", 0) = {m, n}; + *ctx->MutOutputShape("out", 0) = {m, n}; return Maybe::Ok(); } diff --git a/oneflow/user/ops/cum_ops.cpp b/oneflow/user/ops/cum_ops.cpp index 265a201119d..9ee5b5c123a 100644 --- a/oneflow/user/ops/cum_ops.cpp +++ b/oneflow/user/ops/cum_ops.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { Maybe CumsumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } @@ -73,7 +73,7 @@ REGISTER_USER_OP_GRAD("cumsum").SetGenBackwardOpConfFn( }); Maybe CumProdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } @@ -96,7 +96,7 @@ Maybe CumProdOp::InferDataType(user_op::InferContext* ctx) { } Maybe CumProdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/data_shuffle_op.cpp b/oneflow/user/ops/data_shuffle_op.cpp index e8e3ebfa9fa..3f0a4b9abb9 100644 --- a/oneflow/user/ops/data_shuffle_op.cpp +++ b/oneflow/user/ops/data_shuffle_op.cpp @@ -32,10 +32,10 @@ namespace oneflow { CHECK_EQ_OR_RETURN(keys_shape.At(1), num_tables) << "keys cols must equal to num_tables"; } } - *ctx->OutputShape("num_unique", 0) = Shape({1}); - *ctx->OutputShape("unique_keys", 0) = Shape({keys_shape.elem_cnt()}); - *ctx->OutputShape("unique_values", 0) = Shape({keys_shape.elem_cnt()}); - *ctx->OutputShape("inverse_indices", 0) = keys_shape; + *ctx->MutOutputShape("num_unique", 0) = Shape({1}); + *ctx->MutOutputShape("unique_keys", 0) = Shape({keys_shape.elem_cnt()}); + *ctx->MutOutputShape("unique_values", 0) = Shape({keys_shape.elem_cnt()}); + *ctx->MutOutputShape("inverse_indices", 0) = keys_shape; return Maybe::Ok(); } @@ -74,12 +74,12 @@ namespace oneflow { } const int64_t num_ids = ids_shape.elem_cnt(); const int64_t parallel_num = ctx->parallel_num(); - *ctx->OutputShape("num_unique_matrix", 0) = Shape({parallel_num * parallel_num}); - *ctx->OutputShape("inverse_unique_partition_indices", 0) = ids_shape; - 
*ctx->OutputShape("cur_rank_num_unique", 0) = Shape({1}); - *ctx->OutputShape("cur_rank_unique_ids", 0) = Shape({num_ids * parallel_num}); - *ctx->OutputShape("cur_rank_inverse_indices", 0) = Shape({num_ids * parallel_num}); - *ctx->OutputShape("cur_rank_unique_table_ids", 0) = Shape({num_ids * parallel_num}); + *ctx->MutOutputShape("num_unique_matrix", 0) = Shape({parallel_num * parallel_num}); + *ctx->MutOutputShape("inverse_unique_partition_indices", 0) = ids_shape; + *ctx->MutOutputShape("cur_rank_num_unique", 0) = Shape({1}); + *ctx->MutOutputShape("cur_rank_unique_ids", 0) = Shape({num_ids * parallel_num}); + *ctx->MutOutputShape("cur_rank_inverse_indices", 0) = Shape({num_ids * parallel_num}); + *ctx->MutOutputShape("cur_rank_unique_table_ids", 0) = Shape({num_ids * parallel_num}); return Maybe::Ok(); } @@ -135,7 +135,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids); DimVector out_dim_vec = inverse_unique_partition_indices_shape.dim_vec(); out_dim_vec.push_back(embedding_size); - *ctx->OutputShape("embeddings", 0) = Shape(out_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(out_dim_vec); return Maybe::Ok(); } @@ -179,7 +179,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids); DimVector out_dim_vec = cur_rank_inverse_indices_shape.dim_vec(); out_dim_vec.push_back(embedding_size); - *ctx->OutputShape("cur_rank_unique_embedding_grad", 0) = Shape(out_dim_vec); + *ctx->MutOutputShape("cur_rank_unique_embedding_grad", 0) = Shape(out_dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/normal_op.cpp b/oneflow/user/ops/distributions/normal_op.cpp index 736a70e5d0b..769ff12dd2e 100644 --- a/oneflow/user/ops/distributions/normal_op.cpp +++ b/oneflow/user/ops/distributions/normal_op.cpp @@ -21,7 +21,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe NormalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& shape = ctx->Attr("shape"); *out_shape = shape; return Maybe::Ok(); @@ -36,7 +36,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/uniform_int_op.cpp b/oneflow/user/ops/distributions/uniform_int_op.cpp index f01bb710f3c..63b0e39d74d 100644 --- a/oneflow/user/ops/distributions/uniform_int_op.cpp +++ b/oneflow/user/ops/distributions/uniform_int_op.cpp @@ -20,7 +20,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe UniformIntOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& shape = ctx->Attr("shape"); DimVector dim_vec; if (shape.NumAxes() > 0) { @@ -39,7 +39,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/uniform_op.cpp b/oneflow/user/ops/distributions/uniform_op.cpp index b7d566aac49..3ccb8400fab 100644 --- a/oneflow/user/ops/distributions/uniform_op.cpp +++ b/oneflow/user/ops/distributions/uniform_op.cpp @@ -20,7 +20,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe UniformOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& shape = ctx->Attr("shape"); DimVector dim_vec; if (shape.NumAxes() > 0) { @@ -39,7 +39,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/dot_op.cpp b/oneflow/user/ops/dot_op.cpp index 080a8cff539..7ea24b0d9f8 100644 --- a/oneflow/user/ops/dot_op.cpp +++ b/oneflow/user/ops/dot_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { CHECK_OR_RETURN(x.shape().NumAxes() == 1) << Error::RuntimeError() << "1D tensors expected, but got " << x.shape().NumAxes() << "D tensors"; - *ctx->OutputShape("out", 0) = Shape({}); + *ctx->MutOutputShape("out", 0) = Shape({}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp index c23d2ef28af..b74deb9ac06 100644 --- a/oneflow/user/ops/dropout_op.cpp +++ b/oneflow/user/ops/dropout_op.cpp @@ -20,8 +20,8 @@ namespace oneflow { /* static */ Maybe DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - *ctx->OutputShape("out", 0) = in_shape; - *ctx->OutputShape("mask", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("mask", 0) = in_shape; *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -53,7 +53,7 @@ namespace oneflow { /* static */ Maybe DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - *ctx->OutputShape("dx", 0) = dy_shape; + *ctx->MutOutputShape("dx", 0) = dy_shape; *ctx->OutputIsDynamic("dx", 0) = 
ctx->InputIsDynamic("dy", 0); CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape); return Maybe::Ok(); @@ -89,7 +89,7 @@ namespace oneflow { } /* static */ Maybe RandomMaskLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_b_to_s_op.cpp b/oneflow/user/ops/eager_b_to_s_op.cpp index 00cb6aee242..1d415e230f4 100644 --- a/oneflow/user/ops/eager_b_to_s_op.cpp +++ b/oneflow/user/ops/eager_b_to_s_op.cpp @@ -39,7 +39,7 @@ namespace oneflow { int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } - *ctx->OutputShape("out", 0) = Shape(dim_vec); + *ctx->MutOutputShape("out", 0) = Shape(dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_nccl_ops.cpp b/oneflow/user/ops/eager_nccl_ops.cpp index 5f574a7b1be..8af86554f51 100644 --- a/oneflow/user/ops/eager_nccl_ops.cpp +++ b/oneflow/user/ops/eager_nccl_ops.cpp @@ -24,7 +24,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EagerNcclAllReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -48,7 +48,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclBroadcastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -96,7 +96,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -120,14 +120,14 @@ namespace oneflow { /* static */ Maybe EagerNcclReduceScatterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /* static */ Maybe EagerNcclReduceScatterOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy(); @@ -179,7 +179,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclAllGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -226,7 +226,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = 
ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_p_to_b_op.cpp b/oneflow/user/ops/eager_p_to_b_op.cpp index f503dfcefd9..e1ad0d5ca3c 100644 --- a/oneflow/user/ops/eager_p_to_b_op.cpp +++ b/oneflow/user/ops/eager_p_to_b_op.cpp @@ -24,7 +24,7 @@ limitations under the License. namespace oneflow { // Can only be called in local /* static */ Maybe EagerPToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_p_to_s_op.cpp b/oneflow/user/ops/eager_p_to_s_op.cpp index d05bb50df12..1731cf321e2 100644 --- a/oneflow/user/ops/eager_p_to_s_op.cpp +++ b/oneflow/user/ops/eager_p_to_s_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } - *ctx->OutputShape("out", 0) = Shape(dim_vec); + *ctx->MutOutputShape("out", 0) = Shape(dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_b_op.cpp b/oneflow/user/ops/eager_s_to_b_op.cpp index e59d98bb520..9c9ff92d53b 100644 --- a/oneflow/user/ops/eager_s_to_b_op.cpp +++ b/oneflow/user/ops/eager_s_to_b_op.cpp @@ -24,7 +24,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EagerSToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_p_op.cpp b/oneflow/user/ops/eager_s_to_p_op.cpp index 711c8d84501..1caa5dfd408 100644 --- a/oneflow/user/ops/eager_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_s_to_p_op.cpp @@ -24,7 +24,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe EagerSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_s_op.cpp b/oneflow/user/ops/eager_s_to_s_op.cpp index f2ec6bc933d..11c36b19649 100644 --- a/oneflow/user/ops/eager_s_to_s_op.cpp +++ b/oneflow/user/ops/eager_s_to_s_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { int64_t parallel_id = opt_parallel_id->value_or(0); dim_vec[out_split_axis] = bs.At(parallel_id).size(); } - *ctx->OutputShape("out", 0) = Shape(dim_vec); + *ctx->MutOutputShape("out", 0) = Shape(dim_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp index 1767d96e9f4..95a3716d106 100644 --- a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp @@ -22,7 +22,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EagerSymmetricSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/elu_op.cpp b/oneflow/user/ops/elu_op.cpp index 9de85d34655..7d32b87d832 100644 --- a/oneflow/user/ops/elu_op.cpp +++ b/oneflow/user/ops/elu_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe EluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe EluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/embedding_op.cpp b/oneflow/user/ops/embedding_op.cpp index 5d124cac674..ab3a0960519 100644 --- a/oneflow/user/ops/embedding_op.cpp +++ b/oneflow/user/ops/embedding_op.cpp @@ -20,7 +20,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe EmbeddingRenormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/empty_op.cpp b/oneflow/user/ops/empty_op.cpp index 92582ad145d..958843bdb03 100644 --- a/oneflow/user/ops/empty_op.cpp +++ b/oneflow/user/ops/empty_op.cpp @@ -38,8 +38,8 @@ Maybe> MakeEmptyStream(const Symbol& out_device, const bo } // namespace /* static */ Maybe EmptyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - *ctx->OutputStride("out", 0) = Stride(Shape(ctx->Attr("shape").dim_vec())); + *ctx->MutOutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + *ctx->MutOutputStride("out", 0) = Stride(Shape(ctx->Attr("shape").dim_vec())); return Maybe::Ok(); } @@ -52,8 +52,8 @@ Maybe> MakeEmptyStream(const Symbol& out_device, const bo GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; - *ctx->OutputStride("out", 0) = Stride(physical_shape); + *ctx->MutOutputShape("out", 0) = physical_shape; + *ctx->MutOutputStride("out", 0) = Stride(physical_shape); return Maybe::Ok(); } diff --git a/oneflow/user/ops/erfinv_op.cpp b/oneflow/user/ops/erfinv_op.cpp index 708e50c89c6..a0467942a39 100644 --- a/oneflow/user/ops/erfinv_op.cpp +++ b/oneflow/user/ops/erfinv_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe ErfInvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); *y_shape = x_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/expand_dims_op.cpp b/oneflow/user/ops/expand_dims_op.cpp index f5031f7a1b3..79392e43258 
100644 --- a/oneflow/user/ops/expand_dims_op.cpp +++ b/oneflow/user/ops/expand_dims_op.cpp @@ -31,7 +31,7 @@ int32_t TransformNegativeAxisToPositive(int32_t axis, const int32_t num_axes) { /* static */ Maybe ExpandDimsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const int32_t axis = TransformNegativeAxisToPositive(ctx->Attr("axis"), in_shape.NumAxes()); diff --git a/oneflow/user/ops/expand_op.cpp b/oneflow/user/ops/expand_op.cpp index 9e8cfd5c2ef..8837793c7a1 100644 --- a/oneflow/user/ops/expand_op.cpp +++ b/oneflow/user/ops/expand_op.cpp @@ -32,7 +32,7 @@ namespace oneflow { std::vector stride; CHECK_JUST(getOutShapeAndStrideForFp(in_shape, logical_expand_shape, out_shape, stride)); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); DimVector dim_vec(out_shape.begin(), out_shape.end()); *output_shape = Shape(dim_vec); @@ -90,7 +90,7 @@ namespace oneflow { CHECK_JUST(getOutShapeAndStrideForBp(logical_out_shape, logical_expand_shape, in_shape, out_shape, stride)); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); DimVector dim_vec(out_shape.begin(), out_shape.end()); *output_shape = Shape(dim_vec); return Maybe::Ok(); diff --git a/oneflow/user/ops/eye_op.cpp b/oneflow/user/ops/eye_op.cpp index 077758b2452..69823ff7943 100644 --- a/oneflow/user/ops/eye_op.cpp +++ b/oneflow/user/ops/eye_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe EyeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { int64_t rows = ctx->Attr("rows"); int64_t cols = ctx->Attr("cols"); - *ctx->OutputShape("out", 0) = Shape({rows, cols}); + *ctx->MutOutputShape("out", 0) = Shape({rows, cols}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fake_quantization_op.cpp 
b/oneflow/user/ops/fake_quantization_op.cpp index fbe6a7d8ca6..bc6dfe54a4b 100644 --- a/oneflow/user/ops/fake_quantization_op.cpp +++ b/oneflow/user/ops/fake_quantization_op.cpp @@ -30,7 +30,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); } - *ctx->OutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/fill_op.cpp b/oneflow/user/ops/fill_op.cpp index 854e9a311e7..064dd54a80c 100644 --- a/oneflow/user/ops/fill_op.cpp +++ b/oneflow/user/ops/fill_op.cpp @@ -20,9 +20,9 @@ namespace oneflow { /* static */ Maybe FillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; - Stride* out_stride = ctx->OutputStride("out", 0); + Stride* out_stride = ctx->MutOutputStride("out", 0); *out_stride = ctx->InputStride("in", 0); return Maybe::Ok(); } @@ -46,9 +46,9 @@ namespace oneflow { /* static */ Maybe FillTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; - Stride* out_stride = ctx->OutputStride("out", 0); + Stride* out_stride = ctx->MutOutputStride("out", 0); *out_stride = ctx->InputStride("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_bias_add_op.cpp b/oneflow/user/ops/fused_bias_add_op.cpp index 46f9394ff18..378e9ed50fe 100644 --- a/oneflow/user/ops/fused_bias_add_op.cpp +++ b/oneflow/user/ops/fused_bias_add_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { CHECK_GE_OR_RETURN(bias_add_axis, 0); CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("out", 0) = 
a_tensor_desc.shape(); + *ctx->MutOutputShape("out", 0) = a_tensor_desc.shape(); *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { CHECK_GE_OR_RETURN(bias_add_axis, 0); CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("dx", 0) = a_tensor_desc.shape(); + *ctx->MutOutputShape("dx", 0) = a_tensor_desc.shape(); *ctx->OutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } @@ -152,7 +152,7 @@ REGISTER_USER_OP_GRAD("fused_bias_add_gelu") CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); - *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); + *ctx->MutOutputShape("out", 0) = a_tensor_desc.shape(); *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_cross_feature_interaction_op.cpp b/oneflow/user/ops/fused_cross_feature_interaction_op.cpp index 5486fc9634a..ca140295ac6 100644 --- a/oneflow/user/ops/fused_cross_feature_interaction_op.cpp +++ b/oneflow/user/ops/fused_cross_feature_interaction_op.cpp @@ -24,11 +24,11 @@ namespace oneflow { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); CHECK_EQ_OR_RETURN(x_shape.At(1), weight_shape.At(1)) << "Matmul K dims should be equal. "; - *ctx->OutputShape("matmul_result", 0) = Shape({x_shape.At(0), weight_shape.At(0)}); + *ctx->MutOutputShape("matmul_result", 0) = Shape({x_shape.At(0), weight_shape.At(0)}); const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& bias_shape = ctx->InputShape("bias", 0); CHECK_EQ_OR_RETURN(bias_shape.At(0), x0_shape.At(1)) << "Bias dim should be equal to X0 dim1. 
"; - *ctx->OutputShape("out", 0) = x0_shape; + *ctx->MutOutputShape("out", 0) = x0_shape; return Maybe::Ok(); } @@ -59,10 +59,10 @@ namespace oneflow { user_op::InferContext* ctx) { const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); - *ctx->OutputShape("dx0", 0) = x0_shape; - *ctx->OutputShape("dw", 0) = weight_shape; - *ctx->OutputShape("dx", 0) = x0_shape; - *ctx->OutputShape("dbias", 0) = Shape({x0_shape.At(1)}); + *ctx->MutOutputShape("dx0", 0) = x0_shape; + *ctx->MutOutputShape("dw", 0) = weight_shape; + *ctx->MutOutputShape("dx", 0) = x0_shape; + *ctx->MutOutputShape("dbias", 0) = Shape({x0_shape.At(1)}); return Maybe::Ok(); } @@ -100,10 +100,10 @@ namespace oneflow { user_op::InferContext* ctx) { const Shape& x0_shape = ctx->InputShape("x0", 0); const Shape& weight_shape = ctx->InputShape("weight", 0); - *ctx->OutputShape("dx0", 0) = x0_shape; - *ctx->OutputShape("dw", 0) = weight_shape; - *ctx->OutputShape("dx", 0) = x0_shape; - *ctx->OutputShape("dbias", 0) = Shape({x0_shape.At(1)}); + *ctx->MutOutputShape("dx0", 0) = x0_shape; + *ctx->MutOutputShape("dw", 0) = weight_shape; + *ctx->MutOutputShape("dx", 0) = x0_shape; + *ctx->MutOutputShape("dbias", 0) = Shape({x0_shape.At(1)}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp index 0d99cf8b489..da1d256eb67 100644 --- a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp +++ b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } const std::string& pooling = ctx->Attr("pooling"); if (pooling == "sum") { - *ctx->OutputShape("out", 0) = Shape({batch_size, vector_size}); + *ctx->MutOutputShape("out", 0) = Shape({batch_size, vector_size}); return Maybe::Ok(); } if (ctx->has_input("sparse_feature", 0)) { @@ -66,7 +66,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(output_concat_shape.At(0), batch_size); out_dim += 
output_concat_shape.At(1); } - *ctx->OutputShape("out", 0) = Shape({batch_size, out_dim}); + *ctx->MutOutputShape("out", 0) = Shape({batch_size, out_dim}); return Maybe::Ok(); } @@ -109,14 +109,14 @@ namespace oneflow { CHECK_EQ_OR_RETURN(ctx->output_size("features_grad"), ctx->input_size("features")) << "features_grad and features must have same size"; for (int64_t i = 0; i < ctx->output_size("features_grad"); ++i) { - *ctx->OutputShape("features_grad", i) = ctx->InputShape("features", i); + *ctx->MutOutputShape("features_grad", i) = ctx->InputShape("features", i); } if (ctx->has_output("output_concat_grad", 0)) { const int32_t output_concat_grad_dim = ctx->Attr("output_concat_grad_dim"); - *ctx->OutputShape("output_concat_grad", 0) = Shape({batch_size, output_concat_grad_dim}); + *ctx->MutOutputShape("output_concat_grad", 0) = Shape({batch_size, output_concat_grad_dim}); } if (ctx->has_output("sparse_feature_grad", 0)) { - *ctx->OutputShape("sparse_feature_grad", 0) = ctx->InputShape("sparse_feature", 0); + *ctx->MutOutputShape("sparse_feature_grad", 0) = ctx->InputShape("sparse_feature", 0); } return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_gru_cell_op.cpp b/oneflow/user/ops/fused_gru_cell_op.cpp index b9b6b7063f1..7b3aaee0e31 100644 --- a/oneflow/user/ops/fused_gru_cell_op.cpp +++ b/oneflow/user/ops/fused_gru_cell_op.cpp @@ -21,8 +21,8 @@ namespace oneflow { /* static */ Maybe FusedGruCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& hx_shape = ctx->InputShape("hx", 0); - *ctx->OutputShape("hy", 0) = hx_shape; - *ctx->OutputShape("workspace", 0) = Shape({hx_shape.At(0), hx_shape.At(1) * 5}); + *ctx->MutOutputShape("hy", 0) = hx_shape; + *ctx->MutOutputShape("workspace", 0) = Shape({hx_shape.At(0), hx_shape.At(1) * 5}); return Maybe::Ok(); } @@ -69,14 +69,14 @@ namespace oneflow { /* static */ Maybe FusedGruCellGradOp ::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& grad_hy_shape = ctx->InputShape("grad_hy", 
0); DimVector dim_vec({grad_hy_shape.At(0), grad_hy_shape.At(1) * 3}); - *ctx->OutputShape("grad_input_gates", 0) = Shape(dim_vec); - *ctx->OutputShape("grad_hidden_gates", 0) = Shape(dim_vec); + *ctx->MutOutputShape("grad_input_gates", 0) = Shape(dim_vec); + *ctx->MutOutputShape("grad_hidden_gates", 0) = Shape(dim_vec); - if (ctx->has_output("grad_hx", 0)) { *ctx->OutputShape("grad_hx", 0) = grad_hy_shape; } + if (ctx->has_output("grad_hx", 0)) { *ctx->MutOutputShape("grad_hx", 0) = grad_hy_shape; } if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { - *ctx->OutputShape("grad_input_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); - *ctx->OutputShape("grad_hidden_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); + *ctx->MutOutputShape("grad_input_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); + *ctx->MutOutputShape("grad_hidden_bias", 0) = Shape({grad_hy_shape.At(1) * 3}); } return Maybe::Ok(); diff --git a/oneflow/user/ops/fused_lstm_cell_op.cpp b/oneflow/user/ops/fused_lstm_cell_op.cpp index 5ce8add4f7b..8cf2663e04c 100644 --- a/oneflow/user/ops/fused_lstm_cell_op.cpp +++ b/oneflow/user/ops/fused_lstm_cell_op.cpp @@ -21,9 +21,9 @@ namespace oneflow { /* static */ Maybe FusedLstmCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& cx_shape = ctx->InputShape("cx", 0); - *ctx->OutputShape("hy", 0) = cx_shape; - *ctx->OutputShape("cy", 0) = cx_shape; - *ctx->OutputShape("workspace", 0) = ctx->InputShape("input_gates", 0); + *ctx->MutOutputShape("hy", 0) = cx_shape; + *ctx->MutOutputShape("cy", 0) = cx_shape; + *ctx->MutOutputShape("workspace", 0) = ctx->InputShape("input_gates", 0); return Maybe::Ok(); } @@ -71,12 +71,14 @@ namespace oneflow { } /* static */ Maybe FusedLstmCellGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("grad_gates", 0) = ctx->InputShape("workspace", 0); + *ctx->MutOutputShape("grad_gates", 0) = ctx->InputShape("workspace", 0); - if (ctx->has_output("grad_cx", 0)) { 
*ctx->OutputShape("grad_cx", 0) = ctx->InputShape("cx", 0); } + if (ctx->has_output("grad_cx", 0)) { + *ctx->MutOutputShape("grad_cx", 0) = ctx->InputShape("cx", 0); + } if (ctx->has_output("grad_bias", 0)) { - *ctx->OutputShape("grad_bias", 0) = Shape({ctx->InputShape("workspace", 0).At(1)}); + *ctx->MutOutputShape("grad_bias", 0) = Shape({ctx->InputShape("workspace", 0).At(1)}); } return Maybe::Ok(); diff --git a/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp b/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp index c473ba7ea57..ced41d69fd8 100644 --- a/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp +++ b/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp @@ -65,12 +65,12 @@ Maybe InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) { // Set Middle result shape. long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld); int64_t aux_size = cublas_aligned_aux_ld / 32; // Cause we use int32_t as dtype - *ctx->OutputShape("cublas_aux", idx) = Shape({m, aux_size}); - *ctx->OutputShape("hidden", idx) = Shape({m, n}); + *ctx->MutOutputShape("cublas_aux", idx) = Shape({m, aux_size}); + *ctx->MutOutputShape("hidden", idx) = Shape({m, n}); // Set for next layer. 
k = n; } - *ctx->OutputShape("out", 0) = {m, n}; + *ctx->MutOutputShape("out", 0) = {m, n}; return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_relu_dropout_grad_op.cpp b/oneflow/user/ops/fused_relu_dropout_grad_op.cpp index 14101dd16c5..5de869d6a45 100644 --- a/oneflow/user/ops/fused_relu_dropout_grad_op.cpp +++ b/oneflow/user/ops/fused_relu_dropout_grad_op.cpp @@ -25,7 +25,7 @@ namespace oneflow { namespace { Maybe InferTensorDesc4FusedReluDropoutGrad(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp index eabeed57b06..0d9973a79fb 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp @@ -27,9 +27,9 @@ namespace oneflow { CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1), mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; - *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->MutOutputShape("y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->MutOutputShape("softmax_y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp index 235e897db47..d8d6ceda8f7 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1), mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; - *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->MutOutputShape("y", 0) = 
x_desc.shape(); *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index 20dead6c8d7..77dd85f57a4 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -20,9 +20,9 @@ namespace oneflow { /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->MutOutputShape("y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->MutOutputShape("softmax_y", 0) = x_desc.shape(); *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp index 232a78189c9..4afaa120388 100644 --- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp +++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp @@ -41,8 +41,8 @@ namespace oneflow { CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0); int64_t num_heads = hidden_size / (head_size * 3); - *ctx->OutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); - *ctx->OutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); + *ctx->MutOutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); + *ctx->MutOutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); return Maybe::Ok(); } @@ -98,7 +98,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len); CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len); - 
*ctx->OutputShape("hidden_states_grad", 0) = h_shape; + *ctx->MutOutputShape("hidden_states_grad", 0) = h_shape; return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferPhysicalTensorDesc( diff --git a/oneflow/user/ops/gelu_op.cpp b/oneflow/user/ops/gelu_op.cpp index 39f12592c23..50c2012c83e 100644 --- a/oneflow/user/ops/gelu_op.cpp +++ b/oneflow/user/ops/gelu_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /*static*/ auto GeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; return Maybe::Ok(); } @@ -42,7 +42,7 @@ namespace oneflow { /*static*/ auto GeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp index 73b7dcb52eb..7d929383f99 100644 --- a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp +++ b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc( user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); + *ctx->MutOutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); return Maybe::Ok(); } /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferPhysicalTensorDesc( diff --git a/oneflow/user/ops/hardshrink_op.cpp b/oneflow/user/ops/hardshrink_op.cpp index 21fdae26a17..362818758b3 100644 --- 
a/oneflow/user/ops/hardshrink_op.cpp +++ b/oneflow/user/ops/hardshrink_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe HardShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe HardShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << "The shape of y_grad and y must be same."; *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/hardsigmoid_op.cpp b/oneflow/user/ops/hardsigmoid_op.cpp index 887614425ac..f56d3392058 100644 --- a/oneflow/user/ops/hardsigmoid_op.cpp +++ b/oneflow/user/ops/hardsigmoid_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe HardsigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; return Maybe::Ok(); } @@ -45,7 +45,7 @@ namespace oneflow { /* static */ Maybe HardsigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/hardswish_op.cpp b/oneflow/user/ops/hardswish_op.cpp index f7dfbc5c870..3342e1d4dbb 100644 --- a/oneflow/user/ops/hardswish_op.cpp +++ b/oneflow/user/ops/hardswish_op.cpp @@ -19,7 +19,7 @@ limitations under the 
License. namespace oneflow { /* static */ Maybe HardswishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe HardswishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/hardtanh_op.cpp b/oneflow/user/ops/hardtanh_op.cpp index 2d5208c7b0b..d2033b79870 100644 --- a/oneflow/user/ops/hardtanh_op.cpp +++ b/oneflow/user/ops/hardtanh_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe HardtanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; double min_val = ctx->Attr("min_val"); double max_val = ctx->Attr("max_val"); @@ -48,7 +48,7 @@ namespace oneflow { /* static */ Maybe HardtanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape); *dx_shape = dy_shape; double min_val = ctx->Attr("min_val"); diff --git a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp index 7ddad5a603f..564960b6e66 100644 --- a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp +++ b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe 
HierarchicalParallelCastOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -57,7 +57,7 @@ namespace oneflow { /* static */ Maybe HierarchicalParallelCastLikeOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/identity_op.cpp b/oneflow/user/ops/identity_op.cpp index 538abeb5dde..10deb96ce54 100644 --- a/oneflow/user/ops/identity_op.cpp +++ b/oneflow/user/ops/identity_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe IdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_object_preprocess_ops.cpp b/oneflow/user/ops/image_object_preprocess_ops.cpp index 5fd2cb99f38..d2b523ec994 100644 --- a/oneflow/user/ops/image_object_preprocess_ops.cpp +++ b/oneflow/user/ops/image_object_preprocess_ops.cpp @@ -35,7 +35,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -66,7 +66,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& flip_code_desc = 
ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("bbox", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); return Maybe::Ok(); } @@ -98,7 +98,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("bbox", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); return Maybe::Ok(); } @@ -132,7 +132,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("poly", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } @@ -167,7 +167,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("poly", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } @@ -194,7 +194,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { /* static */ Maybe ImageNormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = 
ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -227,7 +227,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("poly", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index 00c6d419c8b..20985964a94 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -159,7 +159,7 @@ namespace oneflow { const auto tensor_slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp index 05affa22404..7b57a21bd01 100644 --- a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp +++ b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp @@ -24,7 +24,7 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()); - *ctx->OutputShape("out", 0) = ctx->InputShape("model", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("model", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("model", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/l2_normalize_op.cpp b/oneflow/user/ops/l2_normalize_op.cpp index d1723c41c97..4fed45fad79 100644 --- a/oneflow/user/ops/l2_normalize_op.cpp +++ 
b/oneflow/user/ops/l2_normalize_op.cpp @@ -20,8 +20,8 @@ namespace oneflow { /* static */ Maybe L2NormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - Shape* square_x_sum_shape = ctx->OutputShape("square_x_sum", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); + Shape* square_x_sum_shape = ctx->MutOutputShape("square_x_sum", 0); const int32_t axis = ctx->Attr("axis"); const float epsilon = ctx->Attr("epsilon"); CHECK_GE_OR_RETURN(axis, 0); @@ -62,7 +62,7 @@ namespace oneflow { const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& y_shape = ctx->InputShape("y", 0); const Shape& square_x_sum_shape = ctx->InputShape("square_x_sum", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); const int32_t axis = ctx->Attr("axis"); const float epsilon = ctx->Attr("epsilon"); CHECK_EQ_OR_RETURN(dy_shape, y_shape); diff --git a/oneflow/user/ops/leaky_relu_op.cpp b/oneflow/user/ops/leaky_relu_op.cpp index 09d8b318c54..fb43e8a2bf2 100644 --- a/oneflow/user/ops/leaky_relu_op.cpp +++ b/oneflow/user/ops/leaky_relu_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe LeakyReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); *y_shape = x_shape; return Maybe::Ok(); } @@ -45,7 +45,7 @@ namespace oneflow { /* static */ Maybe LeakyReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/log_softmax_op.cpp b/oneflow/user/ops/log_softmax_op.cpp index 
d8cffbf7460..8064d78941c 100644 --- a/oneflow/user/ops/log_softmax_op.cpp +++ b/oneflow/user/ops/log_softmax_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe LogSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("prob", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("prob", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -46,7 +46,7 @@ namespace oneflow { /* static */ Maybe LogSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("prob", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/masked_fill_op.cpp b/oneflow/user/ops/masked_fill_op.cpp index 327ce994ded..f4cf83edbe5 100644 --- a/oneflow/user/ops/masked_fill_op.cpp +++ b/oneflow/user/ops/masked_fill_op.cpp @@ -22,7 +22,7 @@ namespace { Maybe InferMaskedFillTensorDesc(user_op::InferContext* ctx) { const Shape& mask_shape = ctx->InputShape("mask", 0); - *ctx->OutputShape("out", 0) = mask_shape; + *ctx->MutOutputShape("out", 0) = mask_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp index 0c4ef770ac3..10ad55d4c0c 100644 --- a/oneflow/user/ops/math_binary_broadcast_ops.cpp +++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp @@ -35,21 +35,21 @@ Maybe InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) { size_t output_num_axes = std::max(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes()); if (IsZeroDimTensor(&tensor_x)) { - *ctx->OutputShape("z", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); } else if (IsZeroDimTensor(&tensor_y)) { - *ctx->OutputShape("z", 
0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); } else if (IsScalarTensor(&tensor_x)) { - *ctx->OutputShape("z", 0) = ctx->InputShape("y", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("y", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); } else if (IsScalarTensor(&tensor_y)) { - *ctx->OutputShape("z", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); } else { const auto& x_shape = CreateLeftExtendedShape(ShapeView(tensor_x.shape()), output_num_axes); const auto& y_shape = CreateLeftExtendedShape(ShapeView(tensor_y.shape()), output_num_axes); - *ctx->OutputShape("z", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("z", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); Shape out_shape(x_shape); FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index 9604177ed77..9996bd34850 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -36,7 +36,7 @@ Maybe InferTensorDesc4Matmul(user_op::InferContext* ctx) { user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("a", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); int64_t m, n, k; // tensor a (no trans): m*k, tensor b (no trans): k*n diff --git a/oneflow/user/ops/matrix_vector_product_op.cpp b/oneflow/user/ops/matrix_vector_product_op.cpp index 91cfba1224b..fd987d0745b 100644 --- a/oneflow/user/ops/matrix_vector_product_op.cpp +++ b/oneflow/user/ops/matrix_vector_product_op.cpp @@ -26,7 +26,7 @@ Maybe InferTensorDesc4MatrixVectorProduct(user_op::InferContext* ctx) { int64_t m = a.shape().At(0); int64_t k = a.shape().At(1); 
CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << "Dim K should be equal to vector b's dim0. "; - *ctx->OutputShape("out", 0) = Shape({m}); + *ctx->MutOutputShape("out", 0) = Shape({m}); return Maybe::Ok(); } @@ -47,7 +47,7 @@ Maybe InferTensorDesc4MatrixVectorProductGradA(user_op::InferContext* ctx) const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t m = dy.shape().At(0); int64_t n = b.shape().At(0); - *ctx->OutputShape("dx", 0) = Shape({m, n}); + *ctx->MutOutputShape("dx", 0) = Shape({m, n}); return Maybe::Ok(); } @@ -58,7 +58,7 @@ Maybe InferTensorDesc4MatrixVectorProductGradB(user_op::InferContext* ctx) */ const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); int64_t n = a.shape().At(1); - *ctx->OutputShape("dx", 0) = Shape({n}); + *ctx->MutOutputShape("dx", 0) = Shape({n}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/median_op.cpp b/oneflow/user/ops/median_op.cpp index 5ca4689b037..9c80743b588 100644 --- a/oneflow/user/ops/median_op.cpp +++ b/oneflow/user/ops/median_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { } /*static*/ Maybe MedianOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& ones_shape = {1}; - *ctx->OutputShape("output", 0) = ones_shape.RemoveOnes({0}); + *ctx->MutOutputShape("output", 0) = ones_shape.RemoveOnes({0}); return Maybe::Ok(); } /*static*/ Maybe MedianOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/median_with_indices_op.cpp b/oneflow/user/ops/median_with_indices_op.cpp index d9d0d672735..2aab4ccb8cf 100644 --- a/oneflow/user/ops/median_with_indices_op.cpp +++ b/oneflow/user/ops/median_with_indices_op.cpp @@ -31,8 +31,8 @@ namespace oneflow { } /*static*/ Maybe MedianWithIndicesOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); - Shape* values_shape = ctx->OutputShape("values", 0); - Shape* indices_shape = ctx->OutputShape("indices", 0); + Shape* values_shape = ctx->MutOutputShape("values", 0); + 
Shape* indices_shape = ctx->MutOutputShape("indices", 0); const Shape& reduce_shape = CreateReducedShape(input_shape, {-1}); *values_shape = reduce_shape.RemoveOnes({-1}); *indices_shape = reduce_shape.RemoveOnes({-1}); diff --git a/oneflow/user/ops/min_max_observer_op.cpp b/oneflow/user/ops/min_max_observer_op.cpp index 3d7f186c378..84b68b8cdec 100644 --- a/oneflow/user/ops/min_max_observer_op.cpp +++ b/oneflow/user/ops/min_max_observer_op.cpp @@ -23,16 +23,16 @@ namespace oneflow { if (ctx->Attr("quantization_formula") == "google") { if (ctx->Attr("per_layer_quantization") == true) { - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); + *ctx->MutOutputShape("scale", 0) = Shape({1}); + *ctx->MutOutputShape("zero_point", 0) = Shape({1}); } else { // NOTE(Liang Depeng): For now per-channel quantization only support axis 0 - *ctx->OutputShape("scale", 0) = Shape({in_shape.At(0)}); - *ctx->OutputShape("zero_point", 0) = Shape({in_shape.At(0)}); + *ctx->MutOutputShape("scale", 0) = Shape({in_shape.At(0)}); + *ctx->MutOutputShape("zero_point", 0) = Shape({in_shape.At(0)}); } } else { // quantization_formula == "cambricon" - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); + *ctx->MutOutputShape("scale", 0) = Shape({1}); + *ctx->MutOutputShape("zero_point", 0) = Shape({1}); } return Maybe::Ok(); } diff --git a/oneflow/user/ops/mish_op.cpp b/oneflow/user/ops/mish_op.cpp index bee4ebb18a8..58dd37fdda5 100644 --- a/oneflow/user/ops/mish_op.cpp +++ b/oneflow/user/ops/mish_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe MishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe MishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index 0bcaf045247..cbfbf4b78bf 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -752,7 +752,7 @@ Maybe InferLarsUpdateDataType(user_op::InferContext* ctx) { /* static */ Maybe AdamBiasCorrectionFactorOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("train_step", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("train_step", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/moving_average_min_max_observer_op.cpp b/oneflow/user/ops/moving_average_min_max_observer_op.cpp index 434865f2d59..4e374c2de45 100644 --- a/oneflow/user/ops/moving_average_min_max_observer_op.cpp +++ b/oneflow/user/ops/moving_average_min_max_observer_op.cpp @@ -31,8 +31,8 @@ namespace oneflow { CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1); - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); + *ctx->MutOutputShape("scale", 0) = Shape({1}); + *ctx->MutOutputShape("zero_point", 0) = Shape({1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/multi_reduce_ops.cpp b/oneflow/user/ops/multi_reduce_ops.cpp index 58ceca4ff10..89022884317 100644 --- a/oneflow/user/ops/multi_reduce_ops.cpp 
+++ b/oneflow/user/ops/multi_reduce_ops.cpp @@ -23,7 +23,7 @@ namespace { Maybe InferMultiReduceOpShape(user_op::InferContext* ctx) { CHECK_GT_OR_RETURN(ctx->input_size("x"), 0) << ctx->op_name() << "must have at least 1 input"; - *ctx->OutputShape("y", 0) = Shape({}); + *ctx->MutOutputShape("y", 0) = Shape({}); return Maybe::Ok(); } @@ -67,13 +67,13 @@ Maybe InferLocalMultiReduceOpLogicalShape(user_op::InferContext* ctx) { for (int64_t i = 0; i < rank_mesh->NumAxes(); ++i) { if (any_nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= rank_mesh->At(i); } } - *ctx->OutputShape("y", 0) = Shape({split_num}); + *ctx->MutOutputShape("y", 0) = Shape({split_num}); return Maybe::Ok(); } Maybe InferLocalMultiReduceOpPhysicalShape(user_op::InferContext* ctx) { CHECK_GT_OR_RETURN(ctx->input_size("x"), 0) << ctx->op_name() << "must have at least 1 input"; - *ctx->OutputShape("y", 0) = Shape({1}); + *ctx->MutOutputShape("y", 0) = Shape({1}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/narrow_op.cpp b/oneflow/user/ops/narrow_op.cpp index a8569c6784e..275041ad1a5 100644 --- a/oneflow/user/ops/narrow_op.cpp +++ b/oneflow/user/ops/narrow_op.cpp @@ -83,7 +83,7 @@ namespace oneflow { const int64_t ndim = dy_shape.NumAxes(); CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim); - *ctx->OutputShape("dx", 0) = like_shape; + *ctx->MutOutputShape("dx", 0) = like_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp index f8bf37f2771..13c39cd301e 100644 --- a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp +++ b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp @@ -23,7 +23,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ 
-65,7 +65,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -107,7 +107,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -150,7 +150,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -195,7 +195,7 @@ _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDeviceAndStream( /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nccl_logical_ops.cpp b/oneflow/user/ops/nccl_logical_ops.cpp index 5f157516389..54baf57426c 100644 --- a/oneflow/user/ops/nccl_logical_ops.cpp +++ b/oneflow/user/ops/nccl_logical_ops.cpp @@ -23,7 +23,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = 
ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -62,7 +62,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalReduceScatterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -103,7 +103,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -143,7 +143,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -185,7 +185,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -230,7 +230,7 @@ namespace oneflow { } /* static */ Maybe _ncclLogicalS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -269,7 +269,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalSendRecvOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + 
*ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nd_index_slice_ops.cpp b/oneflow/user/ops/nd_index_slice_ops.cpp index 2fa17d2d390..bdbae09b336 100644 --- a/oneflow/user/ops/nd_index_slice_ops.cpp +++ b/oneflow/user/ops/nd_index_slice_ops.cpp @@ -42,7 +42,7 @@ Maybe InferScatterNdTensorDesc(user_op::InferContext* ctx) { const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& params_shape = ctx->Attr("shape"); JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape)); - *ctx->OutputShape("out", 0) = params_shape; + *ctx->MutOutputShape("out", 0) = params_shape; return Maybe::Ok(); } @@ -56,7 +56,7 @@ Maybe InferScatterNdLikeTensorDesc(user_op::InferContext* ctx) { const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& like_shape = ctx->InputShape("like", 0); JUST(CheckScatterNdShape(like_shape, indices_shape, updates_shape)); - *ctx->OutputShape("out", 0) = like_shape; + *ctx->MutOutputShape("out", 0) = like_shape; return Maybe::Ok(); } @@ -70,7 +70,7 @@ Maybe InferTensorScatterNdOptTensorDesc(user_op::InferContext* ctx) { const Shape& updates_shape = ctx->InputShape("updates", 0); const Shape& indices_shape = ctx->InputShape("indices", 0); JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape)); - *ctx->OutputShape("out", 0) = params_shape; + *ctx->MutOutputShape("out", 0) = params_shape; return Maybe::Ok(); } @@ -122,7 +122,7 @@ Maybe GetTensorScatterNdOptSbpSignatures(user_op::SbpContext* ctx) { FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) { out_shape_vec.emplace_back(params_shape.At(i)); } - *ctx->OutputShape("out", 0) = Shape(out_shape_vec); + *ctx->MutOutputShape("out", 0) = Shape(out_shape_vec); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nms_op.cpp b/oneflow/user/ops/nms_op.cpp index 1d9c0e29537..ea4d0a4c0f5 100644 --- a/oneflow/user/ops/nms_op.cpp +++ 
b/oneflow/user/ops/nms_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { namespace { Maybe InferNmsTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape({ctx->InputShape("in", 0).At(0)}); + *ctx->MutOutputShape("out", 0) = Shape({ctx->InputShape("in", 0).At(0)}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nvtx_range_op.cpp b/oneflow/user/ops/nvtx_range_op.cpp index 0f2bd54b2e6..c8d3509bc0f 100644 --- a/oneflow/user/ops/nvtx_range_op.cpp +++ b/oneflow/user/ops/nvtx_range_op.cpp @@ -22,7 +22,7 @@ namespace oneflow { #ifdef WITH_CUDA /* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -49,7 +49,7 @@ namespace oneflow { } /* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/one_embedding_ops.cpp b/oneflow/user/ops/one_embedding_ops.cpp index 99938d2d03d..5ac91d991d5 100644 --- a/oneflow/user/ops/one_embedding_ops.cpp +++ b/oneflow/user/ops/one_embedding_ops.cpp @@ -30,7 +30,7 @@ namespace oneflow { DimVector out_dim_vec = ids_shape.dim_vec(); const int64_t embedding_size = ctx->Attr("embedding_size"); out_dim_vec.push_back(embedding_size); - *ctx->OutputShape("embeddings", 0) = Shape(out_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(out_dim_vec); return Maybe::Ok(); } @@ -116,7 +116,7 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape) << "table_ids shape must equal to ids shape"; CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1); - *ctx->OutputShape("context", 0) = 
num_unique_ids_shape; + *ctx->MutOutputShape("context", 0) = num_unique_ids_shape; return Maybe::Ok(); } @@ -155,19 +155,19 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation(); if (ctx->has_output("embeddings", 0)) { if (use_dynamic_memory_allocation) { - *ctx->OutputShape("embeddings", 0) = Shape({1}); + *ctx->MutOutputShape("embeddings", 0) = Shape({1}); } else { DimVector embeddings_dim_vec = unique_ids_shape.dim_vec(); embeddings_dim_vec.push_back(embedding_size); - *ctx->OutputShape("embeddings", 0) = Shape(embeddings_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(embeddings_dim_vec); } } if (use_dynamic_memory_allocation) { - *ctx->OutputShape("unique_values", 0) = Shape({1}); + *ctx->MutOutputShape("unique_values", 0) = Shape({1}); } else { DimVector unique_values_dim_vec = unique_ids_shape.dim_vec(); unique_values_dim_vec.push_back(line_size); - *ctx->OutputShape("unique_values", 0) = Shape(unique_values_dim_vec); + *ctx->MutOutputShape("unique_values", 0) = Shape(unique_values_dim_vec); } return Maybe::Ok(); @@ -318,7 +318,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -346,7 +346,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 
0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -374,7 +374,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -402,7 +402,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -430,7 +430,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/ones_like_op.cpp b/oneflow/user/ops/ones_like_op.cpp index c64eefc2a0f..74f49c31590 100644 --- a/oneflow/user/ops/ones_like_op.cpp +++ b/oneflow/user/ops/ones_like_op.cpp @@ -33,8 +33,8 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - 
*ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - *ctx->OutputStride("out", 0) = ctx->InputStride("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputStride("out", 0) = ctx->InputStride("like", 0); return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/p2p_comm_op.cpp b/oneflow/user/ops/p2p_comm_op.cpp index 0c6998bdb87..1103106a736 100644 --- a/oneflow/user/ops/p2p_comm_op.cpp +++ b/oneflow/user/ops/p2p_comm_op.cpp @@ -48,7 +48,7 @@ Maybe> GetRecvOutputDeivce(user_op::DeviceAndStreamInferContext* /*static*/ Maybe RecvOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); } /*static*/ Maybe RecvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->Attr("shape"); + *ctx->MutOutputShape("out", 0) = ctx->Attr("shape"); return Maybe::Ok(); } /*static*/ Maybe RecvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/pad_op.cpp b/oneflow/user/ops/pad_op.cpp index d1d020ed355..ce545d812f5 100644 --- a/oneflow/user/ops/pad_op.cpp +++ b/oneflow/user/ops/pad_op.cpp @@ -40,7 +40,7 @@ namespace oneflow { FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i]; } - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + *ctx->MutOutputShape("y", 0) = Shape(y_dim_vec); return Maybe::Ok(); } /*static*/ Maybe PadOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/padding_ops.cpp b/oneflow/user/ops/padding_ops.cpp index 41ef1da54ea..400f846e9fd 100644 --- a/oneflow/user/ops/padding_ops.cpp +++ b/oneflow/user/ops/padding_ops.cpp @@ -74,7 +74,7 @@ Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + *ctx->MutOutputShape("y", 0) = 
Shape(y_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -113,7 +113,7 @@ Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + *ctx->MutOutputShape("dx", 0) = Shape(dx_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -162,7 +162,7 @@ REGISTER_USER_OP_GRAD("reflection_pad2d") y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + *ctx->MutOutputShape("y", 0) = Shape(y_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -201,7 +201,7 @@ REGISTER_USER_OP_GRAD("reflection_pad2d") dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + *ctx->MutOutputShape("dx", 0) = Shape(dx_dim_vec); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/parallel_cast_op.cpp b/oneflow/user/ops/parallel_cast_op.cpp index 9d25b9504de..e24f264cd8a 100644 --- a/oneflow/user/ops/parallel_cast_op.cpp +++ b/oneflow/user/ops/parallel_cast_op.cpp @@ -23,7 +23,7 @@ namespace oneflow { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe ParallelCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp 
index 1798e91fe6d..9ca056933aa 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -111,11 +111,11 @@ namespace oneflow { } /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { - *ctx->OutputShape("boxing_disabled_sampled_weight_diff", 0) = + *ctx->MutOutputShape("boxing_disabled_sampled_weight_diff", 0) = ctx->InputShape("sampled_weight_diff", 0); *ctx->OutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = ctx->InputIsDynamic("sampled_weight_diff", 0); - *ctx->OutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); + *ctx->MutOutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); *ctx->OutputIsDynamic("boxing_disabled_sampled_label", 0) = ctx->InputIsDynamic("sampled_label", 0); return Maybe::Ok(); diff --git a/oneflow/user/ops/prelu_op.cpp b/oneflow/user/ops/prelu_op.cpp index 6cd352ba5ba..1b19189f328 100644 --- a/oneflow/user/ops/prelu_op.cpp +++ b/oneflow/user/ops/prelu_op.cpp @@ -40,7 +40,7 @@ namespace oneflow { } /*static*/ Maybe PreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); + Shape* y_shape = ctx->MutOutputShape("y", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); *y_shape = x_shape; @@ -91,8 +91,8 @@ namespace oneflow { /*static*/ Maybe PreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - Shape* alpha_diff_shape = ctx->OutputShape("alpha_diff", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); + Shape* alpha_diff_shape = ctx->MutOutputShape("alpha_diff", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); 
CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1)); diff --git a/oneflow/user/ops/quantization_op.cpp b/oneflow/user/ops/quantization_op.cpp index 2396a1a1685..759b65472bf 100644 --- a/oneflow/user/ops/quantization_op.cpp +++ b/oneflow/user/ops/quantization_op.cpp @@ -68,7 +68,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); } - *ctx->OutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; return Maybe::Ok(); } /*static*/ Maybe QuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/randperm_op.cpp b/oneflow/user/ops/randperm_op.cpp index 956902154ae..7075f37327d 100644 --- a/oneflow/user/ops/randperm_op.cpp +++ b/oneflow/user/ops/randperm_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { } /*static*/ Maybe RandpermOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } /*static*/ Maybe RandpermOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); int32_t n = ctx->Attr("n"); CHECK_GE_OR_RETURN(n, 0) << Error::RuntimeError() << "Trying to create tensor with negative dimension " << n << ":" @@ -45,7 +45,7 @@ namespace oneflow { GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id); const Shape& physical_shape = tensor_slice_view.shape(); - *ctx->OutputShape("out", 0) = physical_shape; + *ctx->MutOutputShape("out", 0) = physical_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/reduce_ops.cpp b/oneflow/user/ops/reduce_ops.cpp index 5ac0a70038c..fbfcff77d8f 100644 --- a/oneflow/user/ops/reduce_ops.cpp +++ b/oneflow/user/ops/reduce_ops.cpp @@ -23,8 +23,8 @@ namespace oneflow { Maybe InferTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input_tensor", 0); const auto& reduce_axes = ctx->Attr>("axis"); - Shape* 
output_shape = ctx->OutputShape("output_tensor", 0); - Stride* output_stride = ctx->OutputStride("output_tensor", 0); + Shape* output_shape = ctx->MutOutputShape("output_tensor", 0); + Stride* output_stride = ctx->MutOutputStride("output_tensor", 0); // For 0-dim Tensor if (reduce_axes.empty()) { *output_shape = input_shape; diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp index 38e4f58328a..6b87f2fd4c0 100644 --- a/oneflow/user/ops/relu_op.cpp +++ b/oneflow/user/ops/relu_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { } /*static*/ Maybe ReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("x", 0); - Shape* out_shape = ctx->OutputShape("y", 0); + Shape* out_shape = ctx->MutOutputShape("y", 0); *out_shape = in_shape; return Maybe::Ok(); } @@ -53,7 +53,7 @@ namespace oneflow { /*static*/ Maybe ReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "Tensors y and dy must have the same shape"; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/repeat_op.cpp b/oneflow/user/ops/repeat_op.cpp index 60b281854dc..2f00322b3a2 100644 --- a/oneflow/user/ops/repeat_op.cpp +++ b/oneflow/user/ops/repeat_op.cpp @@ -31,7 +31,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/reshape_like_op.cpp b/oneflow/user/ops/reshape_like_op.cpp index 7b11d6de6f0..e40cab51ebd 100644 --- a/oneflow/user/ops/reshape_like_op.cpp +++ 
b/oneflow/user/ops/reshape_like_op.cpp @@ -44,7 +44,7 @@ namespace oneflow { << "The element number of the in tensor must be equal to the element number of the " "like tensor, " << "but got " << in_shape.elem_cnt() << " and " << like_shape.elem_cnt(); - *ctx->OutputShape("out", 0) = like_shape; + *ctx->MutOutputShape("out", 0) = like_shape; return Maybe::Ok(); } /*static*/ Maybe ReshapeLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/roi_align_op.cpp b/oneflow/user/ops/roi_align_op.cpp index c2a45e6eedc..090c29674a5 100644 --- a/oneflow/user/ops/roi_align_op.cpp +++ b/oneflow/user/ops/roi_align_op.cpp @@ -37,7 +37,7 @@ namespace oneflow { CHECK_EQ(rois_shape.NumAxes(), 2); CHECK_EQ(rois_shape.At(1), 5); // y: (R, C, pool_h, pool_w) - *ctx->OutputShape("y", 0) = Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w}); + *ctx->MutOutputShape("y", 0) = Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w}); return Maybe::Ok(); } /*static*/ Maybe RoiAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -81,7 +81,7 @@ namespace oneflow { // y: (R, C, pool_h, pool_w) const Shape& y_shape = Shape({rois_shape.At(0), x_like_shape.At(1), pooled_h, pooled_w}); CHECK_EQ_OR_RETURN(y_shape, dy_shape); - *ctx->OutputShape("dx", 0) = x_like_shape; + *ctx->MutOutputShape("dx", 0) = x_like_shape; return Maybe::Ok(); } /*static*/ Maybe RoiAlignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/roll_op.cpp b/oneflow/user/ops/roll_op.cpp index b07077d814b..01fd2742c3b 100644 --- a/oneflow/user/ops/roll_op.cpp +++ b/oneflow/user/ops/roll_op.cpp @@ -45,7 +45,7 @@ namespace oneflow { } /*static*/ Maybe RollOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - *ctx->OutputShape("out", 0) = in_shape; + *ctx->MutOutputShape("out", 0) = in_shape; return Maybe::Ok(); } /*static*/ Maybe 
RollOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/same_padding_op.cpp b/oneflow/user/ops/same_padding_op.cpp index 267faf5fecf..40ca7ccd3f9 100644 --- a/oneflow/user/ops/same_padding_op.cpp +++ b/oneflow/user/ops/same_padding_op.cpp @@ -108,7 +108,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SamePaddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x_like", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/scalar_logical_op.cpp b/oneflow/user/ops/scalar_logical_op.cpp index 8c0786c2804..a242b67f924 100644 --- a/oneflow/user/ops/scalar_logical_op.cpp +++ b/oneflow/user/ops/scalar_logical_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { return Maybe::Ok(); \ } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); \ *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ return Maybe::Ok(); \ } \ diff --git a/oneflow/user/ops/scalar_math_op.cpp b/oneflow/user/ops/scalar_math_op.cpp index 3627acde3cf..6712023f60c 100644 --- a/oneflow/user/ops/scalar_math_op.cpp +++ b/oneflow/user/ops/scalar_math_op.cpp @@ -42,7 +42,7 @@ Maybe GetSbp4ScalarMul(user_op::SbpContext* ctx) { #define IMPLEMENT_SCALAR_MATH_OP_FUNCS(op_name, get_sbp_fn) \ /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { return get_sbp_fn(ctx); } \ /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); \ *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ return Maybe::Ok(); \ } \ @@ -71,7 +71,7 @@ 
IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath) return Maybe::Ok(); } /*static*/ Maybe ScalarPowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } /*static*/ Maybe ScalarPowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -92,7 +92,7 @@ IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath) return Maybe::Ok(); } /*static*/ Maybe ScalarReversePowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); return Maybe::Ok(); } /*static*/ Maybe ScalarReversePowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/search_sorted_op.cpp b/oneflow/user/ops/search_sorted_op.cpp index 368114c17ec..1a96a0a9ccb 100644 --- a/oneflow/user/ops/search_sorted_op.cpp +++ b/oneflow/user/ops/search_sorted_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe SearchSortedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("values", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("values", 0); return Maybe::Ok(); } @@ -54,7 +54,7 @@ namespace oneflow { } /* static */ Maybe SearchSortedScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = Shape({}); + *ctx->MutOutputShape("out", 0) = Shape({}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/selu_op.cpp b/oneflow/user/ops/selu_op.cpp index e23a95c8526..cb0de53192e 100644 --- a/oneflow/user/ops/selu_op.cpp +++ b/oneflow/user/ops/selu_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -51,7 +51,7 @@ namespace oneflow { /*static*/ Maybe SeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "Tensors dy and x must be the same shape"; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp index 8e35ae69ab1..cc459d2a605 100644 --- a/oneflow/user/ops/silu_op.cpp +++ b/oneflow/user/ops/silu_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SiluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe 
SiluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -51,7 +51,7 @@ namespace oneflow { /*static*/ Maybe SiluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 3ae88200258..c0b7bea6caa 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -170,7 +170,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { const int64_t diff = stop - start - 1; dim_vec[i] = diff / step + 1; } - *ctx->OutputShape("y", 0) = Shape(dim_vec); + *ctx->MutOutputShape("y", 0) = Shape(dim_vec); return Maybe::Ok(); } /*static*/ Maybe SliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -198,7 +198,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); const TensorSliceView& slice_view = GetTensorSliceView4ParallelId(parallel_hierarchy, y_nd_sbp, logical_shape, parallel_id); - *ctx->OutputShape("y", 0) = Shape(slice_view.shape()); + *ctx->MutOutputShape("y", 0) = Shape(slice_view.shape()); return Maybe::Ok(); } /*static*/ Maybe SliceOp::InferDataType(user_op::InferContext* ctx) { @@ -253,7 +253,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { << Error::RuntimeError() << "The size of step list must be equal to the dimension of ref tensor, " << "but got " << step_vec.size() << " and " << ndim; - *ctx->OutputShape("dx", 0) = like_shape; + *ctx->MutOutputShape("dx", 0) = like_shape; return Maybe::Ok(); } /*static*/ Maybe 
SliceGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp index 1b31f895407..f193e333c5d 100644 --- a/oneflow/user/ops/softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp @@ -51,7 +51,7 @@ namespace oneflow { FOR_RANGE(int64_t, i, 0, num_out_axes) { out_dim_vector.emplace_back(prediction_desc.shape().At(i)); } - *ctx->OutputShape("prob", 0) = ctx->InputShape("prediction", 0); + *ctx->MutOutputShape("prob", 0) = ctx->InputShape("prediction", 0); *ctx->OutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0); user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); @@ -118,7 +118,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape()) << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prob " << prob_desc.shape(); - *ctx->OutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0); + *ctx->MutOutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0); *ctx->OutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/softmax_op.cpp b/oneflow/user/ops/softmax_op.cpp index a726d561073..4dfc29ad88d 100644 --- a/oneflow/user/ops/softmax_op.cpp +++ b/oneflow/user/ops/softmax_op.cpp @@ -29,7 +29,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -54,7 +54,7 @@ namespace oneflow { /*static*/ Maybe SoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const 
Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of y " << y_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/softplus_op.cpp b/oneflow/user/ops/softplus_op.cpp index 2a772b661c0..18ec0cfc439 100644 --- a/oneflow/user/ops/softplus_op.cpp +++ b/oneflow/user/ops/softplus_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe SoftplusOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe SoftplusGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/softshrink_op.cpp b/oneflow/user/ops/softshrink_op.cpp index 95ec290270b..3bed51333d4 100644 --- a/oneflow/user/ops/softshrink_op.cpp +++ b/oneflow/user/ops/softshrink_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe SoftShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe SoftShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& y_shape = ctx->InputShape("y", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of y " << y_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/softsign_op.cpp b/oneflow/user/ops/softsign_op.cpp index 61e45f781e6..2b474b67f19 100644 --- a/oneflow/user/ops/softsign_op.cpp +++ b/oneflow/user/ops/softsign_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SoftsignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SoftsignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -51,7 +51,7 @@ namespace oneflow { /*static*/ Maybe SoftsignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << "The size of dy " << dy_shape << " must match the size of x " << x_shape; *dx_shape = dy_shape; diff --git a/oneflow/user/ops/sort_op.cpp b/oneflow/user/ops/sort_op.cpp index f2dd5e6f89b..5c3add243b3 100644 --- a/oneflow/user/ops/sort_op.cpp +++ b/oneflow/user/ops/sort_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { return 
Maybe::Ok(); } /*static*/ Maybe SortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } /*static*/ Maybe SortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp index b661910fe8c..adce0aa9b7f 100644 --- a/oneflow/user/ops/sparse_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp @@ -62,7 +62,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape()) << Error::RuntimeError() << "The size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); - *ctx->OutputShape("prediction_diff", 0) = prediction_desc.shape(); + *ctx->MutOutputShape("prediction_diff", 0) = prediction_desc.shape(); *ctx->OutputIsDynamic("prediction_diff", 0) = prediction_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp index 0d77af3f218..7e02cb9fd23 100644 --- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp @@ -43,7 +43,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { } *ctx->OutputIsDynamic("prob", 0) = prediction_desc.is_dynamic(); // 'prob' is just for compute prediction's grad, prob's grad will be ignored - *ctx->OutputShape("prob", 0) = prediction_desc.shape(); + *ctx->MutOutputShape("prob", 0) = prediction_desc.shape(); user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); *out_desc->mut_shape() = label_desc.shape(); @@ -75,7 +75,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape()) << Error::RuntimeError() << "The 
size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); - *ctx->OutputShape("prediction_diff", 0) = prob_desc.shape(); + *ctx->MutOutputShape("prediction_diff", 0) = prob_desc.shape(); *ctx->OutputIsDynamic("prediction_diff", 0) = prob_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/squeeze_op.cpp b/oneflow/user/ops/squeeze_op.cpp index d6c9cb111a4..5fe2422a6a8 100644 --- a/oneflow/user/ops/squeeze_op.cpp +++ b/oneflow/user/ops/squeeze_op.cpp @@ -63,7 +63,7 @@ Maybe CheckAndLabelAxesToSqueezeMinusOne(const AxisVector& axes, DimVector } /*static*/ Maybe SqueezeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); AxisVector fixed_axes_vec; JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_shape.NumAxes(), &fixed_axes_vec)); diff --git a/oneflow/user/ops/ssp_variable_proxy_op.cpp b/oneflow/user/ops/ssp_variable_proxy_op.cpp index 9a5a31262a7..00299abcd86 100644 --- a/oneflow/user/ops/ssp_variable_proxy_op.cpp +++ b/oneflow/user/ops/ssp_variable_proxy_op.cpp @@ -31,8 +31,8 @@ namespace oneflow { } /*static*/ Maybe SspVariableProxyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& var_shape = ctx->InputShape("var", 0); - *ctx->OutputShape("ref", 0) = var_shape; - *ctx->OutputShape("value", 0) = var_shape; + *ctx->MutOutputShape("ref", 0) = var_shape; + *ctx->MutOutputShape("value", 0) = var_shape; return Maybe::Ok(); } /*static*/ Maybe SspVariableProxyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/tf_pool_op.cpp b/oneflow/user/ops/tf_pool_op.cpp index 39afc8478b8..73a6ab3380e 100644 --- a/oneflow/user/ops/tf_pool_op.cpp +++ b/oneflow/user/ops/tf_pool_op.cpp @@ -51,7 +51,7 @@ TensorDescInferFn MakeFwTensorDescInferFn(const int32_t dim) { } Maybe BwTensorDescInferFn(user_op::InferContext* ctx) 
{ - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->MutOutputShape("dx", 0) = ctx->InputShape("x", 0); *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/tf_prelu_op.cpp b/oneflow/user/ops/tf_prelu_op.cpp index b4880e201e7..f183d82e607 100644 --- a/oneflow/user/ops/tf_prelu_op.cpp +++ b/oneflow/user/ops/tf_prelu_op.cpp @@ -102,7 +102,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type()); *dx_desc->mut_shape() = x_desc.shape(); *dx_desc->mut_is_dynamic() = x_desc.is_dynamic(); - *ctx->OutputShape("alpha_diff", 0) = alpha_desc.shape(); + *ctx->MutOutputShape("alpha_diff", 0) = alpha_desc.shape(); *ctx->OutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/threshold_op.cpp b/oneflow/user/ops/threshold_op.cpp index 3cf10ab9dae..f2ad58f111f 100644 --- a/oneflow/user/ops/threshold_op.cpp +++ b/oneflow/user/ops/threshold_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe ThresholdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { /* static */ Maybe ThresholdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(dy_shape == x_shape); *dx_shape = dy_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/to_contiguous_op.cpp b/oneflow/user/ops/to_contiguous_op.cpp index 95a80c3e1b6..09ce23959f8 100644 --- a/oneflow/user/ops/to_contiguous_op.cpp +++ b/oneflow/user/ops/to_contiguous_op.cpp @@ -24,8 +24,8 @@ namespace oneflow { } /*static*/ Maybe ToContiguousOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputStride("out", 0) = Stride(in_desc.shape()); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputStride("out", 0) = Stride(in_desc.shape()); return Maybe::Ok(); } /*static*/ Maybe ToContiguousOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/top_k_op.cpp b/oneflow/user/ops/top_k_op.cpp index 0bcf295d5bd..c41051e8252 100644 --- a/oneflow/user/ops/top_k_op.cpp +++ b/oneflow/user/ops/top_k_op.cpp @@ -29,7 +29,7 @@ namespace oneflow { } /*static*/ Maybe TopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); *out_shape = in_shape; out_shape->Set(in_shape.NumAxes() - 1, std::min(ctx->Attr("k"), static_cast(in_shape.dim_vec().back()))); diff --git 
a/oneflow/user/ops/tuple_identity_op.cpp b/oneflow/user/ops/tuple_identity_op.cpp index dd98f2fef74..7e2631989d0 100644 --- a/oneflow/user/ops/tuple_identity_op.cpp +++ b/oneflow/user/ops/tuple_identity_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { const int64_t in_size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); for (int64_t i = 0; i < in_size; ++i) { - *ctx->OutputShape("out", i) = ctx->InputShape("in", i); + *ctx->MutOutputShape("out", i) = ctx->InputShape("in", i); *ctx->IsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i); } return Maybe::Ok(); diff --git a/oneflow/user/ops/two_stage_reduce_ops.cpp b/oneflow/user/ops/two_stage_reduce_ops.cpp index 9fbb79e1da0..0c65508c8b6 100644 --- a/oneflow/user/ops/two_stage_reduce_ops.cpp +++ b/oneflow/user/ops/two_stage_reduce_ops.cpp @@ -33,7 +33,7 @@ Maybe InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx const Shape& input_shape = ctx->InputShape("in", 0); const auto& axis = ctx->Attr>("axis"); const int64_t num_axes = input_shape.NumAxes(); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); if (axis.empty()) { *output_shape = Shape::Ones(num_axes); } else { @@ -63,8 +63,8 @@ Maybe InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx *output_shape = Shape(dim_vec); } - *ctx->OutputShape("mask", 0) = input_shape; - *ctx->OutputShape("count", 0) = *output_shape; + *ctx->MutOutputShape("mask", 0) = input_shape; + *ctx->MutOutputShape("count", 0) = *output_shape; return Maybe::Ok(); } @@ -72,7 +72,7 @@ Maybe InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx Maybe InferReduceDeviceStagePhysicalTensorDescFn(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("in", 0); const auto& axis = ctx->Attr>("axis"); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); if (axis.empty()) { *output_shape = 
Shape::Ones(input_shape.NumAxes()); } else { @@ -81,8 +81,8 @@ Maybe InferReduceDeviceStagePhysicalTensorDescFn(user_op::InferContext* ct *output_shape = reduced_shape; } - *ctx->OutputShape("mask", 0) = input_shape; - *ctx->OutputShape("count", 0) = *output_shape; + *ctx->MutOutputShape("mask", 0) = input_shape; + *ctx->MutOutputShape("count", 0) = *output_shape; return Maybe::Ok(); } @@ -96,7 +96,7 @@ Maybe InferReduceDeviceStageGradDtypeFn(user_op::InferContext* ctx) { Maybe InferReduceDeviceStageGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputShape("out_diff", 0), ctx->InputShape("count", 0)); - *ctx->OutputShape("in_diff", 0) = ctx->InputShape("mask", 0); + *ctx->MutOutputShape("in_diff", 0) = ctx->InputShape("mask", 0); return Maybe::Ok(); } @@ -114,7 +114,7 @@ Maybe InferReduceGlobalStageTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(input_shape, device_count_shape); const auto& axis = ctx->Attr>("axis"); bool keepdims = ctx->Attr("keepdims"); - Shape* output_shape = ctx->OutputShape("out", 0); + Shape* output_shape = ctx->MutOutputShape("out", 0); if (axis.empty()) { if (keepdims) { *output_shape = Shape::Ones(input_shape.NumAxes()); @@ -131,7 +131,7 @@ Maybe InferReduceGlobalStageTensorDescFn(user_op::InferContext* ctx) { } } - *ctx->OutputShape("mask", 0) = input_shape; + *ctx->MutOutputShape("mask", 0) = input_shape; return Maybe::Ok(); } @@ -149,7 +149,7 @@ Maybe InferReduceGlobalStageGradTensorDescFn(user_op::InferContext* ctx) { const Shape& mask_shape = ctx->InputShape("mask", 0); const Shape& device_count_shape = ctx->InputShape("device_count", 0); CHECK_EQ_OR_RETURN(device_count_shape, mask_shape); - *ctx->OutputShape("in_diff", 0) = mask_shape; + *ctx->MutOutputShape("in_diff", 0) = mask_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/unfold_fold_op.cpp b/oneflow/user/ops/unfold_fold_op.cpp index 0560561604c..ce851cce8c7 100644 --- a/oneflow/user/ops/unfold_fold_op.cpp +++ 
b/oneflow/user/ops/unfold_fold_op.cpp @@ -58,7 +58,7 @@ Maybe UnfoldTensorDescInferFn(user_op::InferContext* ctx) { * std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); y_shape.at(2) = std::accumulate(dhw_shape.begin(), dhw_shape.end(), 1, std::multiplies()); - *ctx->OutputShape("y", 0) = Shape(y_shape); + *ctx->MutOutputShape("y", 0) = Shape(y_shape); return Maybe::Ok(); } @@ -118,7 +118,7 @@ Maybe FoldTensorDescInferFn(user_op::InferContext* ctx) { y_shape.at(2) = output_size[0]; y_shape.at(3) = output_size[1]; - *ctx->OutputShape("y", 0) = Shape(y_shape); + *ctx->MutOutputShape("y", 0) = Shape(y_shape); return Maybe::Ok(); } diff --git a/oneflow/user/ops/unfold_tensor_op.cpp b/oneflow/user/ops/unfold_tensor_op.cpp index 52d1c068e6b..03c24c7bc29 100644 --- a/oneflow/user/ops/unfold_tensor_op.cpp +++ b/oneflow/user/ops/unfold_tensor_op.cpp @@ -57,7 +57,7 @@ namespace oneflow { out_shape.at(d) = in_size_at_d; } } - *ctx->OutputShape("y", 0) = Shape(out_shape); + *ctx->MutOutputShape("y", 0) = Shape(out_shape); return Maybe::Ok(); } /*static*/ Maybe UnfoldTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/unsorted_segment_sum_op.cpp b/oneflow/user/ops/unsorted_segment_sum_op.cpp index 5df5e81e451..76d03477f23 100644 --- a/oneflow/user/ops/unsorted_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_segment_sum_op.cpp @@ -52,7 +52,7 @@ namespace oneflow { const Shape& data_shape = ctx->InputShape("data", 0); const int64_t axis = ctx->Attr("axis"); const int64_t num_segments = ctx->Attr("num_segments"); - Shape* out_shape = ctx->OutputShape("out", 0); + Shape* out_shape = ctx->MutOutputShape("out", 0); const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); DimVector dim_vec; @@ -163,7 +163,7 @@ REGISTER_USER_OP_GRAD("unsorted_segment_sum") FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); } - 
*ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); *ctx->IsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index e1d05c1b097..2edea6f8b12 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -244,7 +244,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 3) << "upsample_linear_1d_grad only supports NCH"; @@ -269,7 +269,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 3) << "upsample_nearest_1d_grad only supports NCH"; @@ -295,7 +295,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_nearest_2d_grad only supports NCHW"; @@ -322,7 +322,7 @@ namespace oneflow { /*static*/ Maybe UpsampleBilinear2DGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); 
CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_bilinear_2d_grad only supports NCHW"; @@ -348,7 +348,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 4) << "upsample_bicubic_2d_grad only supports NCHW"; @@ -374,7 +374,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 5) << "upsample_nearest_3d_grad only supports NCDHW"; @@ -401,7 +401,7 @@ namespace oneflow { /*static*/ Maybe UpsampleTrilinear3DGradOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* dx_shape = ctx->MutOutputShape("dx", 0); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && dy_shape.NumAxes() == 5) << "upsample_trilinear_3d_grad only supports NCDHW"; diff --git a/oneflow/user/ops/util_ops.cpp b/oneflow/user/ops/util_ops.cpp index 0be4ce5f115..2b4a68a986a 100644 --- a/oneflow/user/ops/util_ops.cpp +++ b/oneflow/user/ops/util_ops.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe IsNanOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } @@ -43,7 +43,7 @@ namespace oneflow { } /* static */ Maybe IsInfOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/variance_op.cpp b/oneflow/user/ops/variance_op.cpp index c1e578e6947..33caa475c58 100644 --- a/oneflow/user/ops/variance_op.cpp +++ b/oneflow/user/ops/variance_op.cpp @@ -27,7 +27,7 @@ Maybe VarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const AxisVector reduce_axes_vec = {reduce_axes.begin(), reduce_axes.end()}; const Shape& reduce_shape = CreateReducedShape(input_shape, reduce_axes_vec); const bool keepdim = ctx->Attr("keepdim"); - Shape* output_shape = ctx->OutputShape("output", 0); + Shape* output_shape = ctx->MutOutputShape("output", 0); if (keepdim) { *output_shape = reduce_shape; } else { diff --git a/oneflow/user/ops/vector_matrix_product_op.cpp b/oneflow/user/ops/vector_matrix_product_op.cpp index 834ace4ab4c..6d85721cd30 100644 --- a/oneflow/user/ops/vector_matrix_product_op.cpp +++ b/oneflow/user/ops/vector_matrix_product_op.cpp @@ -26,7 +26,7 @@ Maybe InferTensorDesc4VectorMatrixProduct(user_op::InferContext* ctx) { int64_t k = a.shape().At(0); CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << "Dim K should be equal to vector b's dim0. 
"; int64_t n = b.shape().At(1); - *ctx->OutputShape("out", 0) = Shape({n}); + *ctx->MutOutputShape("out", 0) = Shape({n}); return Maybe::Ok(); } @@ -45,7 +45,7 @@ Maybe InferTensorDesc4VectorMatrixProductGradA(user_op::InferContext* ctx) */ const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); int64_t k = b.shape().At(0); - *ctx->OutputShape("dx", 0) = Shape({k}); + *ctx->MutOutputShape("dx", 0) = Shape({k}); return Maybe::Ok(); } @@ -58,7 +58,7 @@ Maybe InferTensorDesc4VectorMatrixProductGradB(user_op::InferContext* ctx) const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); int64_t k = a.shape().At(0); int64_t n = dy.shape().At(0); - *ctx->OutputShape("dx", 0) = Shape({k, n}); + *ctx->MutOutputShape("dx", 0) = Shape({k, n}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index e49ffb19fe6..4a4baf75285 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -81,11 +81,11 @@ Maybe InferWhereTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& y_shape = ctx->InputShape("y", 0); if (x_shape == y_shape && y_shape == cond_shape) { - *ctx->OutputShape("out", 0) = cond_shape; + *ctx->MutOutputShape("out", 0) = cond_shape; } else { Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape)); max_shape = *JUST(GetBroadcastShape(max_shape, y_shape)); - *ctx->OutputShape("out", 0) = max_shape; + *ctx->MutOutputShape("out", 0) = max_shape; } return Maybe::Ok(); } @@ -94,10 +94,10 @@ Maybe InferWhereXScalarTensorDesc(user_op::InferContext* ctx) { const Shape& cond_shape = ctx->InputShape("condition", 0); const Shape& y_shape = ctx->InputShape("y", 0); if (cond_shape == y_shape) { - *ctx->OutputShape("out", 0) = cond_shape; + *ctx->MutOutputShape("out", 0) = cond_shape; } else { Shape max_shape = *JUST(GetBroadcastShape(cond_shape, y_shape)); - *ctx->OutputShape("out", 0) = max_shape; + *ctx->MutOutputShape("out", 0) = max_shape; } return 
Maybe::Ok(); } @@ -106,16 +106,16 @@ Maybe InferWhereYScalarTensorDesc(user_op::InferContext* ctx) { const Shape& cond_shape = ctx->InputShape("condition", 0); const Shape& x_shape = ctx->InputShape("x", 0); if (cond_shape == x_shape) { - *ctx->OutputShape("out", 0) = cond_shape; + *ctx->MutOutputShape("out", 0) = cond_shape; } else { Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape)); - *ctx->OutputShape("out", 0) = max_shape; + *ctx->MutOutputShape("out", 0) = max_shape; } return Maybe::Ok(); } Maybe InferWhereXYScalarTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("condition", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("condition", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/zero_like_op.cpp b/oneflow/user/ops/zero_like_op.cpp index ad648779684..e301865998f 100644 --- a/oneflow/user/ops/zero_like_op.cpp +++ b/oneflow/user/ops/zero_like_op.cpp @@ -33,7 +33,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe ZeroLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->MutOutputShape("out", 0) = ctx->InputShape("like", 0); return Maybe::Ok(); } /*static*/ Maybe ZeroLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { From 4d7490b870f171c4882dfce1d5e2a49921a859ec Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 09:32:19 +0800 Subject: [PATCH 40/67] fix merge master error --- oneflow/user/ops/one_embedding_ops.cpp | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/oneflow/user/ops/one_embedding_ops.cpp b/oneflow/user/ops/one_embedding_ops.cpp index e2399e07345..10e9e91e969 100644 --- a/oneflow/user/ops/one_embedding_ops.cpp +++ b/oneflow/user/ops/one_embedding_ops.cpp @@ -30,7 +30,7 @@ namespace oneflow { DimVector out_dim_vec = ids_shape.dim_vec(); const int64_t embedding_size = ctx->Attr("embedding_size"); 
out_dim_vec.push_back(embedding_size); - *ctx->MutOutputShape("embeddings", 0) = Shape(out_dim_vec); + *ctx->MutMutOutputShape("embeddings", 0) = Shape(out_dim_vec); return Maybe::Ok(); } @@ -116,7 +116,7 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape) << "table_ids shape must equal to ids shape"; CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1); - *ctx->MutOutputShape("context", 0) = num_unique_ids_shape; + *ctx->MutMutOutputShape("context", 0) = num_unique_ids_shape; return Maybe::Ok(); } @@ -155,19 +155,19 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation(); if (ctx->has_output("embeddings", 0)) { if (use_dynamic_memory_allocation) { - *ctx->MutOutputShape("embeddings", 0) = Shape({1}); + *ctx->MutMutOutputShape("embeddings", 0) = Shape({1}); } else { DimVector embeddings_dim_vec = unique_ids_shape.dim_vec(); embeddings_dim_vec.push_back(embedding_size); - *ctx->MutOutputShape("embeddings", 0) = Shape(embeddings_dim_vec); + *ctx->MutMutOutputShape("embeddings", 0) = Shape(embeddings_dim_vec); } } if (use_dynamic_memory_allocation) { - *ctx->MutOutputShape("unique_values", 0) = Shape({1}); + *ctx->MutMutOutputShape("unique_values", 0) = Shape({1}); } else { DimVector unique_values_dim_vec = unique_ids_shape.dim_vec(); unique_values_dim_vec.push_back(line_size); - *ctx->MutOutputShape("unique_values", 0) = Shape(unique_values_dim_vec); + *ctx->MutMutOutputShape("unique_values", 0) = Shape(unique_values_dim_vec); } return Maybe::Ok(); @@ -318,7 +318,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutOutputShape("updated_unique_embeddings", 0) = 
unique_embeddings_shape; + *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -346,7 +346,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -374,7 +374,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -402,7 +402,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -430,7 +430,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - 
*ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -462,14 +462,14 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { } /*static*/ Maybe IdShuffleCopyOutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - *ctx->OutputShape("out_num_unique_matrix", 0) = ctx->InputShape("num_unique_matrix", 0); - *ctx->OutputShape("out_inverse_unique_partition_indices", 0) = + *ctx->MutOutputShape("out_num_unique_matrix", 0) = ctx->InputShape("num_unique_matrix", 0); + *ctx->MutOutputShape("out_inverse_unique_partition_indices", 0) = ctx->InputShape("inverse_unique_partition_indices", 0); - *ctx->OutputShape("out_cur_rank_num_unique", 0) = ctx->InputShape("cur_rank_num_unique", 0); - *ctx->OutputShape("out_cur_rank_unique_ids", 0) = ctx->InputShape("cur_rank_unique_ids", 0); - *ctx->OutputShape("out_cur_rank_unique_table_ids", 0) = + *ctx->MutOutputShape("out_cur_rank_num_unique", 0) = ctx->InputShape("cur_rank_num_unique", 0); + *ctx->MutOutputShape("out_cur_rank_unique_ids", 0) = ctx->InputShape("cur_rank_unique_ids", 0); + *ctx->MutOutputShape("out_cur_rank_unique_table_ids", 0) = ctx->InputShape("cur_rank_unique_table_ids", 0); - *ctx->OutputShape("out_cur_rank_inverse_indices", 0) = + *ctx->MutOutputShape("out_cur_rank_inverse_indices", 0) = ctx->InputShape("cur_rank_inverse_indices", 0); return Maybe::Ok(); } From d4ac72fccdbcd4e7b517291714219a26e8eb32fd Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 09:33:50 +0800 Subject: [PATCH 41/67] fix typo --- oneflow/user/ops/one_embedding_ops.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/oneflow/user/ops/one_embedding_ops.cpp b/oneflow/user/ops/one_embedding_ops.cpp index 10e9e91e969..1c50437e865 100644 --- a/oneflow/user/ops/one_embedding_ops.cpp +++ b/oneflow/user/ops/one_embedding_ops.cpp @@ -30,7 +30,7 @@ namespace 
oneflow { DimVector out_dim_vec = ids_shape.dim_vec(); const int64_t embedding_size = ctx->Attr("embedding_size"); out_dim_vec.push_back(embedding_size); - *ctx->MutMutOutputShape("embeddings", 0) = Shape(out_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(out_dim_vec); return Maybe::Ok(); } @@ -116,7 +116,7 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape) << "table_ids shape must equal to ids shape"; CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1); - *ctx->MutMutOutputShape("context", 0) = num_unique_ids_shape; + *ctx->MutOutputShape("context", 0) = num_unique_ids_shape; return Maybe::Ok(); } @@ -155,19 +155,19 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation(); if (ctx->has_output("embeddings", 0)) { if (use_dynamic_memory_allocation) { - *ctx->MutMutOutputShape("embeddings", 0) = Shape({1}); + *ctx->MutOutputShape("embeddings", 0) = Shape({1}); } else { DimVector embeddings_dim_vec = unique_ids_shape.dim_vec(); embeddings_dim_vec.push_back(embedding_size); - *ctx->MutMutOutputShape("embeddings", 0) = Shape(embeddings_dim_vec); + *ctx->MutOutputShape("embeddings", 0) = Shape(embeddings_dim_vec); } } if (use_dynamic_memory_allocation) { - *ctx->MutMutOutputShape("unique_values", 0) = Shape({1}); + *ctx->MutOutputShape("unique_values", 0) = Shape({1}); } else { DimVector unique_values_dim_vec = unique_ids_shape.dim_vec(); unique_values_dim_vec.push_back(line_size); - *ctx->MutMutOutputShape("unique_values", 0) = Shape(unique_values_dim_vec); + *ctx->MutOutputShape("unique_values", 0) = Shape(unique_values_dim_vec); } return Maybe::Ok(); @@ -318,7 +318,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size) << "get " << line_size << " " << embedding_size; const Shape& 
unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -346,7 +346,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -374,7 +374,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -402,7 +402,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 2) << "get " << line_size << " " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } @@ -430,7 +430,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { CHECK_NE_OR_RETURN(line_size, 0) << "should set attr line_size"; CHECK_EQ_OR_RETURN(line_size, embedding_size * 3) << "get " << line_size << 
" " << embedding_size; const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); - *ctx->MutMutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + *ctx->MutOutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; return Maybe::Ok(); } From e7ac48a445b9891a1c57593f09793137522260b3 Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 09:47:02 +0800 Subject: [PATCH 42/67] fix static check error --- oneflow/core/eager/eager_blob_object.cpp | 1 + oneflow/core/framework/tensor.cpp | 2 +- oneflow/core/framework/tensor_impl.cpp | 2 +- oneflow/core/functional/impl/array_functor.cpp | 6 +++--- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/oneflow/core/eager/eager_blob_object.cpp b/oneflow/core/eager/eager_blob_object.cpp index 2d9e092c2c5..a83d5653d38 100644 --- a/oneflow/core/eager/eager_blob_object.cpp +++ b/oneflow/core/eager/eager_blob_object.cpp @@ -38,6 +38,7 @@ EagerBlobObject::EagerBlobObject( mem_ptr_for_allocation_compuation_pipelining_(nullptr), inited_mem_ptr_for_allocation_compuation_pipelining_(false), is_non_pod_object_placement_newed_(false), + pin_memory_(false), compute_local_dep_object_(dep_object), blob_desc_(static_cast(mut_local_tensor_meta) ? 
std::const_pointer_cast(mut_local_tensor_meta->shape_ptr()) diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index cef75642ba5..1a3049c8815 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -72,7 +72,7 @@ std::shared_ptr Parameter::pin_memory() const { } else { const auto& impl = std::make_shared(requires_grad, is_leaf); const auto& dep_object = NewLocalDepObject(); - impl->InitEagerBlobObject(tensor_meta, dep_object); + JUST(impl->InitEagerBlobObject(tensor_meta, dep_object)); return std::make_shared(impl); } } diff --git a/oneflow/core/framework/tensor_impl.cpp b/oneflow/core/framework/tensor_impl.cpp index bf882e53c26..4424a5e4adb 100644 --- a/oneflow/core/framework/tensor_impl.cpp +++ b/oneflow/core/framework/tensor_impl.cpp @@ -107,7 +107,7 @@ Maybe EagerLocalTensorImpl::InitEagerBlobObject( const Symbol& local_tensor_meta, const std::shared_ptr& mut_local_tensor_meta, const intrusive::shared_ptr& dep_object) { - CHECK_OR_RETURN(static_cast(local_tensor_meta->device())); + CHECK_OR_RETURN(static_cast(local_tensor_meta->device())); // NOLINT const auto& mem_case = local_tensor_meta->device()->mem_case(); if (tensor_storage_) { diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 28d53b44de6..ccf867236e4 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -1225,9 +1225,9 @@ class InplaceToContiguousFunctor { std::shared_ptr final_tensor_impl = std::make_shared(JUST(input->tensor_storage()), input->requires_grad(), input->is_leaf()); - final_tensor_impl->set_retain_grad(input->retain_grad()); - final_tensor_impl->InitEagerBlobObject(new_tensor_meta, - JUST(blob_object->compute_local_dep_object())); + JUST(final_tensor_impl->set_retain_grad(input->retain_grad())); + JUST(final_tensor_impl->InitEagerBlobObject(new_tensor_meta, + 
JUST(blob_object->compute_local_dep_object()))); JUST(JUST(input->AsLocalTensor())->set_impl(final_tensor_impl)); // assign contiguous tensor data From 8008a8dc17e3b57fbbd588680a115a72941e5b04 Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 14:07:52 +0800 Subject: [PATCH 43/67] define_mut_output_dtype_and_mut_output_is_dynamic_in_infer_ctx --- oneflow/core/framework/infer_util.cpp | 4 +- oneflow/core/framework/infer_util.h | 12 ++-- oneflow/core/framework/op_expr.cpp | 28 +++++++-- oneflow/core/kernel/user_kernel.cpp | 28 +++++++-- oneflow/core/operator/user_op.cpp | 32 +++++++--- oneflow/ir/oneflow-extension/extension.cpp | 4 +- oneflow/user/kernels/arg_where_kernel.cpp | 2 +- ...ttention_query_mul_key_and_value_kernel.cu | 2 +- oneflow/user/kernels/stateful_opkernel.cpp | 58 ++++++++++++++----- oneflow/user/ops/acc_op.cpp | 4 +- oneflow/user/ops/adaptive_pool_op.cpp | 6 +- oneflow/user/ops/affine_grid_op.cpp | 4 +- oneflow/user/ops/arange_op.cpp | 2 +- oneflow/user/ops/arg_sort_op.cpp | 2 +- oneflow/user/ops/argmax_op.cpp | 2 +- oneflow/user/ops/as_strided_op.cpp | 4 +- oneflow/user/ops/avg_pool_op.cpp | 4 +- oneflow/user/ops/bias_add_op.cpp | 4 +- oneflow/user/ops/binary_cross_entropy_op.cpp | 4 +- .../binary_cross_entropy_with_logits_op.cpp | 4 +- ...oss_entropy_with_logits_reduce_mean_op.cpp | 4 +- oneflow/user/ops/broadcast_div_grad_op.cpp | 4 +- oneflow/user/ops/broadcast_like_op.cpp | 2 +- oneflow/user/ops/broadcast_pow_grad_op.cpp | 8 +-- oneflow/user/ops/buffer_op.cpp | 4 +- oneflow/user/ops/cast_like_op.cpp | 2 +- oneflow/user/ops/cast_to_static_shape_op.cpp | 2 +- oneflow/user/ops/cast_to_tick_op.cpp | 2 +- .../ops/categorical_ordinal_encode_op.cpp | 2 +- oneflow/user/ops/celu_op.cpp | 4 +- oneflow/user/ops/clip_by_value_op.cpp | 4 +- oneflow/user/ops/combined_margin_loss_op.cpp | 10 ++-- oneflow/user/ops/constant_op.cpp | 2 +- oneflow/user/ops/conv_op.cpp | 10 ++-- oneflow/user/ops/copy_op.cpp | 4 +- oneflow/user/ops/ctc_loss_op.cpp | 10 
++-- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 6 +- oneflow/user/ops/cum_ops.cpp | 6 +- oneflow/user/ops/data_shuffle_op.cpp | 28 ++++----- oneflow/user/ops/deconv_op.cpp | 2 +- oneflow/user/ops/diag_op.cpp | 4 +- oneflow/user/ops/diagonal_op.cpp | 4 +- oneflow/user/ops/dim_scatter_ops.cpp | 4 +- oneflow/user/ops/distributions/normal_op.cpp | 2 +- .../user/ops/distributions/uniform_int_op.cpp | 2 +- oneflow/user/ops/distributions/uniform_op.cpp | 2 +- oneflow/user/ops/dot_op.cpp | 2 +- oneflow/user/ops/dropout_op.cpp | 12 ++-- oneflow/user/ops/eager_b_to_s_op.cpp | 2 +- oneflow/user/ops/eager_nccl_ops.cpp | 16 ++--- oneflow/user/ops/eager_p_to_b_op.cpp | 2 +- oneflow/user/ops/eager_p_to_s_op.cpp | 2 +- oneflow/user/ops/eager_s_to_b_op.cpp | 2 +- oneflow/user/ops/eager_s_to_p_op.cpp | 2 +- oneflow/user/ops/eager_s_to_s_op.cpp | 2 +- .../user/ops/eager_symmetric_s_to_p_op.cpp | 2 +- oneflow/user/ops/elu_op.cpp | 4 +- oneflow/user/ops/embedding_op.cpp | 6 +- oneflow/user/ops/empty_op.cpp | 2 +- oneflow/user/ops/erfinv_op.cpp | 2 +- oneflow/user/ops/expand_dims_op.cpp | 2 +- oneflow/user/ops/expand_op.cpp | 4 +- oneflow/user/ops/eye_op.cpp | 2 +- oneflow/user/ops/fake_quantization_op.cpp | 2 +- oneflow/user/ops/fill_op.cpp | 4 +- oneflow/user/ops/flatten_op.cpp | 2 +- oneflow/user/ops/flip_op.cpp | 2 +- oneflow/user/ops/fused_bias_add_op.cpp | 12 ++-- .../fused_cross_feature_interaction_op.cpp | 20 +++---- .../ops/fused_dot_feature_interaction_op.cpp | 8 +-- oneflow/user/ops/fused_gru_cell_op.cpp | 16 ++--- oneflow/user/ops/fused_lstm_cell_op.cpp | 12 ++-- .../user/ops/fused_relu_dropout_grad_op.cpp | 2 +- .../fused_scale_mask_softmax_dropout_op.cpp | 8 +-- .../user/ops/fused_scale_mask_softmax_op.cpp | 4 +- ...fused_scale_tril_softmax_mask_scale_op.cpp | 8 +-- ..._attention_query_mul_key_and_value_ops.cpp | 6 +- oneflow/user/ops/gelu_op.cpp | 4 +- ...te_random_batch_permutation_indices_op.cpp | 2 +- oneflow/user/ops/grid_sample_op.cpp | 6 +- 
oneflow/user/ops/hardshrink_op.cpp | 4 +- oneflow/user/ops/hardsigmoid_op.cpp | 4 +- oneflow/user/ops/hardswish_op.cpp | 4 +- oneflow/user/ops/hardtanh_op.cpp | 4 +- .../ops/hierarchical_parallel_cast_op.cpp | 8 +-- oneflow/user/ops/identity_op.cpp | 4 +- .../user/ops/image_object_preprocess_ops.cpp | 28 ++++----- oneflow/user/ops/image_preprocess_ops.cpp | 2 +- oneflow/user/ops/kl_div_op.cpp | 4 +- .../user/ops/l1_l2_regularize_gradient_op.cpp | 4 +- oneflow/user/ops/l2_normalize_op.cpp | 6 +- oneflow/user/ops/leaky_relu_op.cpp | 4 +- oneflow/user/ops/log_softmax_op.cpp | 4 +- oneflow/user/ops/logical_not_op.cpp | 2 +- oneflow/user/ops/masked_fill_op.cpp | 2 +- .../user/ops/math_binary_broadcast_ops.cpp | 14 ++--- oneflow/user/ops/matmul_op.cpp | 4 +- oneflow/user/ops/matrix_vector_product_op.cpp | 4 +- oneflow/user/ops/max_pool_op.cpp | 4 +- oneflow/user/ops/median_op.cpp | 2 +- oneflow/user/ops/median_with_indices_op.cpp | 4 +- oneflow/user/ops/min_max_observer_op.cpp | 4 +- oneflow/user/ops/mish_op.cpp | 4 +- oneflow/user/ops/model_update_ops.cpp | 2 +- .../moving_average_min_max_observer_op.cpp | 4 +- oneflow/user/ops/multi_reduce_ops.cpp | 2 +- oneflow/user/ops/narrow_op.cpp | 2 +- oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp | 20 +++---- oneflow/user/ops/nccl_logical_ops.cpp | 28 ++++----- oneflow/user/ops/nd_index_slice_ops.cpp | 8 +-- oneflow/user/ops/nll_op.cpp | 6 +- oneflow/user/ops/nms_op.cpp | 2 +- oneflow/user/ops/nvtx_range_op.cpp | 8 +-- ...frecord_image_classification_reader_op.cpp | 4 +- oneflow/user/ops/ofrecord_reader_op.cpp | 2 +- oneflow/user/ops/one_embedding_ops.cpp | 30 +++++----- oneflow/user/ops/onerec_reader_op.cpp | 2 +- oneflow/user/ops/ones_like_op.cpp | 2 +- oneflow/user/ops/p2p_comm_op.cpp | 2 +- oneflow/user/ops/pack_op.cpp | 2 +- oneflow/user/ops/pad_op.cpp | 2 +- oneflow/user/ops/padding_ops.cpp | 8 +-- oneflow/user/ops/parallel_cast_op.cpp | 4 +- oneflow/user/ops/partial_fc_sample_op.cpp | 14 ++--- oneflow/user/ops/prelu_op.cpp 
| 6 +- oneflow/user/ops/quantization_op.cpp | 2 +- oneflow/user/ops/randperm_op.cpp | 2 +- oneflow/user/ops/reduce_like_ops.cpp | 2 +- oneflow/user/ops/reduce_ops.cpp | 4 +- oneflow/user/ops/relu_op.cpp | 4 +- oneflow/user/ops/repeat_interleave_op.cpp | 2 +- oneflow/user/ops/repeat_op.cpp | 4 +- oneflow/user/ops/reshape_like_op.cpp | 2 +- oneflow/user/ops/reshape_op.cpp | 2 +- oneflow/user/ops/roc_auc_score_op.cpp | 2 +- oneflow/user/ops/roi_align_op.cpp | 4 +- oneflow/user/ops/roll_op.cpp | 2 +- oneflow/user/ops/same_padding_op.cpp | 6 +- oneflow/user/ops/scalar_logical_op.cpp | 4 +- oneflow/user/ops/scalar_math_op.cpp | 8 +-- oneflow/user/ops/search_sorted_op.cpp | 8 +-- oneflow/user/ops/selu_op.cpp | 4 +- oneflow/user/ops/sigmoid_cross_entropy_op.cpp | 4 +- oneflow/user/ops/silu_op.cpp | 4 +- oneflow/user/ops/slice_op.cpp | 4 +- oneflow/user/ops/smooth_l1_loss_op.cpp | 4 +- oneflow/user/ops/softmax_cross_entropy_op.cpp | 8 +-- oneflow/user/ops/softmax_op.cpp | 4 +- oneflow/user/ops/softplus_op.cpp | 4 +- oneflow/user/ops/softshrink_op.cpp | 4 +- oneflow/user/ops/softsign_op.cpp | 4 +- oneflow/user/ops/sort_op.cpp | 2 +- oneflow/user/ops/sparse_cross_entropy_op.cpp | 4 +- .../ops/sparse_softmax_cross_entropy_op.cpp | 10 ++-- oneflow/user/ops/sqrt_square_sum_op.cpp | 2 +- oneflow/user/ops/square_sum_op.cpp | 2 +- oneflow/user/ops/squeeze_op.cpp | 2 +- oneflow/user/ops/ssp_variable_proxy_op.cpp | 4 +- oneflow/user/ops/tf_pool_op.cpp | 6 +- oneflow/user/ops/tf_prelu_op.cpp | 8 +-- oneflow/user/ops/threshold_op.cpp | 4 +- oneflow/user/ops/to_contiguous_op.cpp | 2 +- oneflow/user/ops/top_k_op.cpp | 2 +- oneflow/user/ops/transpose_ops.cpp | 2 +- oneflow/user/ops/tuple_identity_op.cpp | 6 +- oneflow/user/ops/two_stage_reduce_ops.cpp | 14 ++--- oneflow/user/ops/unfold_fold_op.cpp | 4 +- oneflow/user/ops/unfold_tensor_op.cpp | 4 +- oneflow/user/ops/unsorted_segment_sum_op.cpp | 6 +- oneflow/user/ops/upsample_op.cpp | 28 ++++----- oneflow/user/ops/util_ops.cpp | 4 +- 
oneflow/user/ops/variance_op.cpp | 2 +- oneflow/user/ops/vector_matrix_product_op.cpp | 4 +- oneflow/user/ops/where_op.cpp | 12 ++-- oneflow/user/ops/zero_like_op.cpp | 2 +- 175 files changed, 567 insertions(+), 483 deletions(-) diff --git a/oneflow/core/framework/infer_util.cpp b/oneflow/core/framework/infer_util.cpp index 599f6a9070d..0e39ecdd561 100644 --- a/oneflow/core/framework/infer_util.cpp +++ b/oneflow/core/framework/infer_util.cpp @@ -39,7 +39,7 @@ Maybe TensorDescInferFnUtil::Unchanged(InferContext* ctx) { } for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); - *ctx->OutputIsDynamic(output_arg.first, output_arg.second) = first_tensor_desc->is_dynamic(); + *ctx->MutOutputIsDynamic(output_arg.first, output_arg.second) = first_tensor_desc->is_dynamic(); *ctx->OutputShape(output_arg.first, output_arg.second) = first_tensor_desc->shape(); } return Maybe::Ok(); @@ -58,7 +58,7 @@ Maybe TensorDescInferFnUtil::UnchangedDataType(InferContext* ctx) { } for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); - *ctx->OutputDType(output_arg.first, output_arg.second) = first_tensor_desc->data_type(); + *ctx->MutOutputDType(output_arg.first, output_arg.second) = first_tensor_desc->data_type(); } return Maybe::Ok(); } diff --git a/oneflow/core/framework/infer_util.h b/oneflow/core/framework/infer_util.h index 5b32ea31844..42cd5bffb27 100644 --- a/oneflow/core/framework/infer_util.h +++ b/oneflow/core/framework/infer_util.h @@ -49,8 +49,10 @@ class InferContext { virtual Stride* OutputStride(const std::string&, int32_t) = 0; virtual Stride* Stride4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const DataType& InputDType(const std::string&, int32_t) const = 0; - virtual DataType* OutputDType(const std::string&, int32_t) = 0; - virtual DataType* Dtype4ArgNameAndIndex(const std::string&, int32_t) = 0; + virtual const DataType& OutputDType(const std::string&, 
int32_t) const = 0; + virtual DataType* MutOutputDType(const std::string&, int32_t) = 0; + virtual const DataType& Dtype4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual DataType* MutDtype4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; virtual const std::string& input(const std::string& arg_name, int32_t index) const = 0; @@ -80,8 +82,10 @@ class InferContext { virtual const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual bool InputIsDynamic(const std::string&, int32_t) const = 0; - virtual bool* OutputIsDynamic(const std::string&, int32_t) = 0; - virtual bool* IsDynamic4ArgNameAndIndex(const std::string&, int32_t) = 0; + virtual bool OutputIsDynamic(const std::string&, int32_t) const = 0; + virtual bool* MutOutputIsDynamic(const std::string&, int32_t) = 0; + virtual bool IsDynamic4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual bool* MutIsDynamic4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual int64_t parallel_num() const = 0; diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 13113237061..4fc3ff4e6c2 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -251,21 +251,37 @@ class UserOpExprInferContext : public user_op::InferContext { } const DataType& InputDType(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Dtype4ArgNameAndIndex(arg_name, index); + return Dtype4ArgNameAndIndex(arg_name, index); } - DataType* OutputDType(const std::string& arg_name, int32_t index) override { + const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } - DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + DataType* MutOutputDType(const std::string& arg_name, int32_t index) 
override { + return MutDtype4ArgNameAndIndex(arg_name, index); + } + const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->data_type(); + } + DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->IsDynamic4ArgNameAndIndex(arg_name, index); + return IsDynamic4ArgNameAndIndex(arg_name, index); } - bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override { + bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } - bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + bool* MutOutputIsDynamic(const std::string& arg_name, int32_t index) override { + return MutIsDynamic4ArgNameAndIndex(arg_name, index); + } + bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->is_dynamic(); + } + bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); } const std::string& input(const std::string& arg_name, int32_t index) const override { diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 12c40c20d2a..8467f5149cb 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -279,21 +279,37 @@ class UserKernelOpInferContext : public user_op::InferContext { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } const DataType& InputDType(const std::string& arg_name, int32_t index) const override { - return 
*const_cast(this)->Dtype4ArgNameAndIndex(arg_name, index); + return Dtype4ArgNameAndIndex(arg_name, index); } - DataType* OutputDType(const std::string& arg_name, int32_t index) override { + const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } - DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { + return MutDtype4ArgNameAndIndex(arg_name, index); + } + const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->data_type(); + } + DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->IsDynamic4ArgNameAndIndex(arg_name, index); + return IsDynamic4ArgNameAndIndex(arg_name, index); } - bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override { + bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } - bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + bool* MutOutputIsDynamic(const std::string& arg_name, int32_t index) override { + return MutIsDynamic4ArgNameAndIndex(arg_name, index); + } + bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return const_cast(this) + ->TensorDesc4ArgNameAndIndex(arg_name, index) + ->is_dynamic(); + } + bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); } diff --git a/oneflow/core/operator/user_op.cpp 
b/oneflow/core/operator/user_op.cpp index e7e9d8c2d2f..126020a50cc 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -193,23 +193,39 @@ class UserOpInferContext final : public user_op::InferContext { return it->second.mut_stride(); } const DataType& InputDType(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->Dtype4ArgNameAndIndex(arg_name, index); + return Dtype4ArgNameAndIndex(arg_name, index); } - DataType* OutputDType(const std::string& arg_name, int32_t index) override { + const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } - DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { + return MutDtype4ArgNameAndIndex(arg_name, index); + } + const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { return DataType::kInvalidDataType; }; + return it->second.data_type(); + } + DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return it->second.mut_data_type(); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->IsDynamic4ArgNameAndIndex(arg_name, index); + return IsDynamic4ArgNameAndIndex(arg_name, index); } - bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override { + bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); } - bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + bool* 
MutOutputIsDynamic(const std::string& arg_name, int32_t index) override { + return MutIsDynamic4ArgNameAndIndex(arg_name, index); + } + bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { return false; }; + return it->second.is_dynamic(); + } + bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return it->second.mut_is_dynamic(); @@ -611,10 +627,10 @@ Maybe UserOp::InferOutBlobDescs( JUST(val_->physical_tensor_desc_infer_fn(&infer_ctx)); for (const auto& pair : infer_ctx.outputs()) { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second)); - out_blob_desc->set_data_type(*(infer_ctx.OutputDType(pair.first, pair.second))); + out_blob_desc->set_data_type(infer_ctx.OutputDType(pair.first, pair.second)); out_blob_desc->mut_shape() = *(infer_ctx.OutputShape(pair.first, pair.second)); out_blob_desc->mut_stride() = Stride(*(infer_ctx.OutputShape(pair.first, pair.second))); - out_blob_desc->set_is_dynamic(*infer_ctx.OutputIsDynamic(pair.first, pair.second)); + out_blob_desc->set_is_dynamic(*infer_ctx.MutOutputIsDynamic(pair.first, pair.second)); } return Maybe::Ok(); } diff --git a/oneflow/ir/oneflow-extension/extension.cpp b/oneflow/ir/oneflow-extension/extension.cpp index 9954ed6dd8d..8d504c486ac 100644 --- a/oneflow/ir/oneflow-extension/extension.cpp +++ b/oneflow/ir/oneflow-extension/extension.cpp @@ -51,7 +51,7 @@ REGISTER_USER_OP("mlir_jit") const Shape& in_shape = ctx->InputShape("in", 0); Shape* out_shape = ctx->OutputShape("out", 0); *out_shape = in_shape; - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 1); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 1); return Maybe::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) 
-> Maybe { @@ -65,7 +65,7 @@ REGISTER_USER_OP("mlir_jit") return Maybe::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); }); diff --git a/oneflow/user/kernels/arg_where_kernel.cpp b/oneflow/user/kernels/arg_where_kernel.cpp index 97ffeaa015e..6413530089c 100644 --- a/oneflow/user/kernels/arg_where_kernel.cpp +++ b/oneflow/user/kernels/arg_where_kernel.cpp @@ -76,7 +76,7 @@ size_t InferTempStorageBytesSize(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); if (input_shape.NumAxes() == 0) { return 0; } const DataType& input_dtype = ctx->InputDType("input", 0); - DataType output_dtype = *ctx->OutputDType("output", 0); + DataType output_dtype = ctx->OutputDType("output", 0); return SwitchUtil::SwitchGetWorkspaceBytesSize( SwitchCase(device_type, input_dtype, output_dtype, input_shape.NumAxes()), input_shape.elem_cnt()); diff --git a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu index 0243ac36ec7..40ae734c200 100644 --- a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu +++ b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu @@ -267,7 +267,7 @@ class FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel final : public user_op: size_t InferTmpBufferSize(user_op::InferContext* ctx) { const Shape* value_shape = ctx->OutputShape("value", 0); - DataType value_dtype = *ctx->OutputDType("value", 0); + DataType value_dtype = ctx->OutputDType("value", 0); return value_shape->elem_cnt() * GetSizeOfDataType(value_dtype); } diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 0808219276f..13d3d4a4eb4 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ 
b/oneflow/user/kernels/stateful_opkernel.cpp @@ -198,26 +198,42 @@ class UserOpInferContextHelper final { } const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index); + return Dtype4ArgNameAndIndex(call_ctx, arg_name, index); } - DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + const DataType& OutputDType(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return Dtype4ArgNameAndIndex(call_ctx, arg_name, index); } - DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + DataType* MutOutputDType(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return MutDtype4ArgNameAndIndex(call_ctx, arg_name, index); + } + const DataType& Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->data_type(); + } + DataType* MutDtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type(); } bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); + return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); } - bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + bool OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); } - bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + bool* 
MutOutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return MutIsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); + } + bool IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->is_dynamic(); + } + bool* MutIsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic(); } @@ -335,20 +351,32 @@ class UserOpInferContext : public user_op::InferContext { const DataType& InputDType(const std::string& arg_name, int32_t index) const override { return helper_->InputDType(call_ctx_, arg_name, index); } - DataType* OutputDType(const std::string& arg_name, int32_t index) override { + const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { return helper_->OutputDType(call_ctx_, arg_name, index); } - DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { + return helper_->MutOutputDType(call_ctx_, arg_name, index); + } + const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index); } + DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return helper_->MutDtype4ArgNameAndIndex(call_ctx_, arg_name, index); + } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { return helper_->InputIsDynamic(call_ctx_, arg_name, index); } - bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override { + bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { return helper_->OutputIsDynamic(call_ctx_, arg_name, index); } - bool* 
IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index); + bool* MutOutputIsDynamic(const std::string& arg_name, int32_t index) override { + return helper_->MutOutputIsDynamic(call_ctx_, arg_name, index); + } + bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return helper_->MutIsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index); + } + bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return helper_->MutIsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index); } const ArgVec& inputs() const override { return helper_->inputs(); } diff --git a/oneflow/user/ops/acc_op.cpp b/oneflow/user/ops/acc_op.cpp index 92df9df8f8e..19646bac222 100644 --- a/oneflow/user/ops/acc_op.cpp +++ b/oneflow/user/ops/acc_op.cpp @@ -31,14 +31,14 @@ namespace oneflow { } /*static*/ Maybe AccOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } /*static*/ Maybe AccOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return AccOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe AccOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /*static*/ Maybe AccOp::InferOutputBlobTimeShape( diff --git a/oneflow/user/ops/adaptive_pool_op.cpp b/oneflow/user/ops/adaptive_pool_op.cpp index 935e644ea83..a358464e0d5 100644 --- a/oneflow/user/ops/adaptive_pool_op.cpp +++ b/oneflow/user/ops/adaptive_pool_op.cpp @@ -37,7 +37,7 @@ Maybe InferFWTensorDesc(user_op::InferContext* ctx) { Maybe InferBWTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dx", 0) = 
ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); + *ctx->MutOutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } @@ -63,12 +63,12 @@ Maybe BwGetSbpFn(user_op::SbpContext* ctx) { } Maybe InferFWDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } Maybe InferBWDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/affine_grid_op.cpp b/oneflow/user/ops/affine_grid_op.cpp index 24a8c9c0dd4..1826c039c63 100644 --- a/oneflow/user/ops/affine_grid_op.cpp +++ b/oneflow/user/ops/affine_grid_op.cpp @@ -145,7 +145,7 @@ Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, } /* static */ Maybe AffineGridOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("grid", 0) = ctx->InputDType("theta", 0); + *ctx->MutOutputDType("grid", 0) = ctx->InputDType("theta", 0); return Maybe::Ok(); } @@ -180,7 +180,7 @@ Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, } /* static */ Maybe AffineGridGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dtheta", 0) = ctx->InputDType("dgrid", 0); + *ctx->MutOutputDType("dtheta", 0) = ctx->InputDType("dgrid", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/arange_op.cpp b/oneflow/user/ops/arange_op.cpp index 73585347376..91e4073da75 100644 --- a/oneflow/user/ops/arange_op.cpp +++ b/oneflow/user/ops/arange_op.cpp @@ -105,7 +105,7 @@ namespace oneflow { } /* static */ Maybe ArangeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + *ctx->MutOutputDType("out", 0) = ctx->Attr("dtype"); return Maybe::Ok(); } diff --git a/oneflow/user/ops/arg_sort_op.cpp b/oneflow/user/ops/arg_sort_op.cpp index e4ca90915ff..7b986bd360a 100644 --- 
a/oneflow/user/ops/arg_sort_op.cpp +++ b/oneflow/user/ops/arg_sort_op.cpp @@ -48,7 +48,7 @@ namespace oneflow { } /* static */ Maybe ArgSortOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kInt32; + *ctx->MutOutputDType("out", 0) = DataType::kInt32; return Maybe::Ok(); } diff --git a/oneflow/user/ops/argmax_op.cpp b/oneflow/user/ops/argmax_op.cpp index 58c6581eb29..446dcb7ea96 100644 --- a/oneflow/user/ops/argmax_op.cpp +++ b/oneflow/user/ops/argmax_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { } /* static */ Maybe ArgmaxOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kInt64; + *ctx->MutOutputDType("out", 0) = DataType::kInt64; return Maybe::Ok(); } diff --git a/oneflow/user/ops/as_strided_op.cpp b/oneflow/user/ops/as_strided_op.cpp index c347a627f55..45ae191a59d 100644 --- a/oneflow/user/ops/as_strided_op.cpp +++ b/oneflow/user/ops/as_strided_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ auto AsStridedOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } @@ -54,7 +54,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ auto AsStridedGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/avg_pool_op.cpp b/oneflow/user/ops/avg_pool_op.cpp index e6d1521707d..775d956d77b 100644 --- a/oneflow/user/ops/avg_pool_op.cpp +++ b/oneflow/user/ops/avg_pool_op.cpp @@ -112,12 +112,12 @@ Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { } Maybe FwInferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); 
} Maybe BwInferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/bias_add_op.cpp b/oneflow/user/ops/bias_add_op.cpp index 77dfff37837..410a9401fc6 100644 --- a/oneflow/user/ops/bias_add_op.cpp +++ b/oneflow/user/ops/bias_add_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { << " must match the size of tensor " << b_tensor_desc.shape().ToString() << " at dimension " << bias_add_axis; *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); return Maybe::Ok(); } @@ -64,7 +64,7 @@ namespace oneflow { } /* static */ Maybe BiasAddOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("a", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("a", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/binary_cross_entropy_op.cpp b/oneflow/user/ops/binary_cross_entropy_op.cpp index 0d328657660..1b3d8f60416 100644 --- a/oneflow/user/ops/binary_cross_entropy_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_op.cpp @@ -49,7 +49,7 @@ Maybe InferDataType_(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(weight_desc.data_type(), input_desc.data_type()); } - *ctx->OutputDType("out", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } @@ -82,7 +82,7 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(weight_desc.data_type(), input_desc.data_type()); } - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp index 0a124525a60..46eb05e33de 100644 --- 
a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp @@ -55,7 +55,7 @@ Maybe InferDataType_(user_op::InferContext* ctx) { const auto& pos_weight_desc = ctx->InputTensorDesc("pos_weight", 0); CHECK_EQ_OR_RETURN(pos_weight_desc.data_type(), input_desc.data_type()); } - *ctx->OutputDType("out", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } @@ -96,7 +96,7 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { const auto& pos_weight_desc = ctx->InputTensorDesc("pos_weight", 0); CHECK_EQ_OR_RETURN(pos_weight_desc.data_type(), input_desc.data_type()); } - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp b/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp index d32d06fb8c1..273219e85ea 100644 --- a/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp @@ -37,7 +37,7 @@ Maybe InferFwDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << "Input datatype should be equal to Target datatype. "; - *ctx->OutputDType("out", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } @@ -58,7 +58,7 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()) << "Input datatype should be equal to Target datatype. 
"; - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } } // namespace diff --git a/oneflow/user/ops/broadcast_div_grad_op.cpp b/oneflow/user/ops/broadcast_div_grad_op.cpp index c59b2436997..df948fb0b47 100644 --- a/oneflow/user/ops/broadcast_div_grad_op.cpp +++ b/oneflow/user/ops/broadcast_div_grad_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe BroadcastDivGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); - *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); + *ctx->MutOutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { } /* static */ Maybe BroadcastDivGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); + *ctx->MutOutputDType("dy", 0) = ctx->InputDType("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index 1478378ea7f..53ee3415e5f 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -110,7 +110,7 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { } /* static */ Maybe BroadcastLikeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("like", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/broadcast_pow_grad_op.cpp b/oneflow/user/ops/broadcast_pow_grad_op.cpp index 21fa575b03b..15a0ee1800e 100644 --- a/oneflow/user/ops/broadcast_pow_grad_op.cpp +++ b/oneflow/user/ops/broadcast_pow_grad_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe BroadcastPowXGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); + 
*ctx->MutOutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } @@ -71,13 +71,13 @@ namespace oneflow { } /* static */ Maybe BroadcastPowXGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } /* static */ Maybe BroadcastPowYGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); - *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); + *ctx->MutOutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); return Maybe::Ok(); } @@ -116,7 +116,7 @@ namespace oneflow { } /* static */ Maybe BroadcastPowYGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); + *ctx->MutOutputDType("dy", 0) = ctx->InputDType("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/buffer_op.cpp b/oneflow/user/ops/buffer_op.cpp index eb8abde1ee6..1dc638d3621 100644 --- a/oneflow/user/ops/buffer_op.cpp +++ b/oneflow/user/ops/buffer_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe IdentityBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -41,7 +41,7 @@ namespace oneflow { } /* static */ Maybe IdentityBufferOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index c4d41a00be8..714b00ac013 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe 
CastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_to_static_shape_op.cpp b/oneflow/user/ops/cast_to_static_shape_op.cpp index 20843124a24..2b73703db8e 100644 --- a/oneflow/user/ops/cast_to_static_shape_op.cpp +++ b/oneflow/user/ops/cast_to_static_shape_op.cpp @@ -46,7 +46,7 @@ namespace oneflow { } /* static */ Maybe CastToStaticShapeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_to_tick_op.cpp b/oneflow/user/ops/cast_to_tick_op.cpp index bb76f5887e6..8af44e9ff69 100644 --- a/oneflow/user/ops/cast_to_tick_op.cpp +++ b/oneflow/user/ops/cast_to_tick_op.cpp @@ -53,7 +53,7 @@ namespace oneflow { } /* static */ Maybe CastToTickOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index ca2b4533826..b4bf743cce8 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -72,7 +72,7 @@ namespace oneflow { CHECK_OR_RETURN(IsIndexDataType(data_type)); CHECK_EQ_OR_RETURN(ctx->InputDType("table", 0), data_type); CHECK_EQ_OR_RETURN(ctx->InputDType("size", 0), data_type); - *ctx->OutputDType("out", 0) = data_type; + *ctx->MutOutputDType("out", 0) = data_type; return Maybe::Ok(); } diff --git a/oneflow/user/ops/celu_op.cpp b/oneflow/user/ops/celu_op.cpp index 60d48152434..55252cfdc7d 100644 --- a/oneflow/user/ops/celu_op.cpp 
+++ b/oneflow/user/ops/celu_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe CeluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { /* static */ Maybe CeluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/clip_by_value_op.cpp b/oneflow/user/ops/clip_by_value_op.cpp index f216e077816..d0b9e4c0a8a 100644 --- a/oneflow/user/ops/clip_by_value_op.cpp +++ b/oneflow/user/ops/clip_by_value_op.cpp @@ -56,12 +56,12 @@ Maybe GetClipGradSbpSignature(user_op::SbpContext* ctx) { } Maybe InferClipTensorDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } Maybe InferClipGradDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index 72854a53928..59e6e34017d 100644 --- a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -25,7 +25,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); - *ctx->IsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); + *ctx->MutIsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); *theta->mut_is_dynamic() = x.is_dynamic(); *theta->mut_shape() = label.shape(); return Maybe::Ok(); @@ -59,8 +59,8 @@ 
namespace oneflow { } /* static */ Maybe CombinedMarginLossOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("theta", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("theta", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -73,7 +73,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); - *ctx->IsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); + *ctx->MutIsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); return Maybe::Ok(); } @@ -99,7 +99,7 @@ namespace oneflow { } /* static */ Maybe CombinedMarginLossGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/constant_op.cpp b/oneflow/user/ops/constant_op.cpp index 62d9bdcc050..854740518c1 100644 --- a/oneflow/user/ops/constant_op.cpp +++ b/oneflow/user/ops/constant_op.cpp @@ -46,7 +46,7 @@ namespace oneflow { } /* static */ Maybe ConstantOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + *ctx->MutOutputDType("out", 0) = ctx->Attr("dtype"); return Maybe::Ok(); } diff --git a/oneflow/user/ops/conv_op.cpp b/oneflow/user/ops/conv_op.cpp index 64940f4d2da..8097b6829e4 100644 --- a/oneflow/user/ops/conv_op.cpp +++ b/oneflow/user/ops/conv_op.cpp @@ -248,7 +248,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o } /* static */ Maybe Conv1DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -270,7 +270,7 @@ Maybe GenerateBackwardOpConf4Conv(const 
user_op::UserOpWrapper& op, user_o } /* static */ Maybe Conv2DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -292,7 +292,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o } /* static */ Maybe Conv3DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -309,7 +309,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); } *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); + *ctx->MutOutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); return Maybe::Ok(); } @@ -342,7 +342,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type()); } - *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x_like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/copy_op.cpp b/oneflow/user/ops/copy_op.cpp index 6b7d5f994f2..ab0b8d23fd5 100644 --- a/oneflow/user/ops/copy_op.cpp +++ b/oneflow/user/ops/copy_op.cpp @@ -44,7 +44,7 @@ Maybe> MakeCopyStream(const Symbol& in_device, /* static */ Maybe CopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); *ctx->OutputStride("out", 0) = ctx->InputStride("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -65,7 +65,7 @@ Maybe> MakeCopyStream(const Symbol& in_device, 
} /* static */ Maybe CopyOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp index b8dee1ad9cc..856a321570b 100644 --- a/oneflow/user/ops/ctc_loss_op.cpp +++ b/oneflow/user/ops/ctc_loss_op.cpp @@ -57,8 +57,8 @@ namespace oneflow { } /* static */ Maybe CtcLossOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("loss", 0) = ctx->InputDType("log_probs", 0); - *ctx->OutputDType("alpha", 0) = ctx->InputDType("log_probs", 0); + *ctx->MutOutputDType("loss", 0) = ctx->InputDType("log_probs", 0); + *ctx->MutOutputDType("alpha", 0) = ctx->InputDType("log_probs", 0); return Maybe::Ok(); } @@ -101,7 +101,7 @@ namespace oneflow { } /* static */ Maybe CtcLossGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("grad", 0) = ctx->InputDType("log_probs", 0); + *ctx->MutOutputDType("grad", 0) = ctx->InputDType("log_probs", 0); return Maybe::Ok(); } @@ -130,8 +130,8 @@ namespace oneflow { } /* static */ Maybe CtcGreedyDecoderOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("decoded", 0) = ctx->InputDType("input_lengths", 0); - *ctx->OutputDType("neg_sum_logits", 0) = ctx->InputDType("log_probs", 0); + *ctx->MutOutputDType("decoded", 0) = ctx->InputDType("input_lengths", 0); + *ctx->MutOutputDType("neg_sum_logits", 0) = ctx->InputDType("log_probs", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp index cf4fd9d3bcd..cac7ea85d75 100644 --- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -41,10 +41,10 @@ Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { "Because last layer's bias_grad is computed by ReduceSum. 
"; const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); for (int idx = weight_num - 1; idx >= 0; idx--) { - *ctx->OutputDType("d_weights", idx) = dy_desc.data_type(); - *ctx->OutputDType("d_biases", idx) = dy_desc.data_type(); + *ctx->MutOutputDType("d_weights", idx) = dy_desc.data_type(); + *ctx->MutOutputDType("d_biases", idx) = dy_desc.data_type(); } - *ctx->OutputDType("d_x", 0) = dy_desc.data_type(); + *ctx->MutOutputDType("d_x", 0) = dy_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cum_ops.cpp b/oneflow/user/ops/cum_ops.cpp index 265a201119d..24fcf81087a 100644 --- a/oneflow/user/ops/cum_ops.cpp +++ b/oneflow/user/ops/cum_ops.cpp @@ -37,7 +37,7 @@ Maybe CumsumOp::GetSbp(user_op::SbpContext* ctx) { } Maybe CumsumOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -91,7 +91,7 @@ Maybe CumProdOp::GetSbp(user_op::SbpContext* ctx) { } Maybe CumProdOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -113,7 +113,7 @@ Maybe CumProdGradOp::GetSbp(user_op::SbpContext* ctx) { } Maybe CumProdGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/data_shuffle_op.cpp b/oneflow/user/ops/data_shuffle_op.cpp index e8e3ebfa9fa..0a22238f442 100644 --- a/oneflow/user/ops/data_shuffle_op.cpp +++ b/oneflow/user/ops/data_shuffle_op.cpp @@ -48,13 +48,13 @@ namespace oneflow { } /* static */ Maybe UniqueKeyValuePairOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("num_unique", 0) = DataType::kInt32; - *ctx->OutputDType("unique_keys", 0) = ctx->InputDType("keys", 0); - 
*ctx->OutputDType("inverse_indices", 0) = DataType::kInt32; + *ctx->MutOutputDType("num_unique", 0) = DataType::kInt32; + *ctx->MutOutputDType("unique_keys", 0) = ctx->InputDType("keys", 0); + *ctx->MutOutputDType("inverse_indices", 0) = DataType::kInt32; if (ctx->has_input("values", 0)) { - *ctx->OutputDType("unique_values", 0) = ctx->InputDType("values", 0); + *ctx->MutOutputDType("unique_values", 0) = ctx->InputDType("values", 0); } else { - *ctx->OutputDType("unique_values", 0) = DataType::kInt32; + *ctx->MutOutputDType("unique_values", 0) = DataType::kInt32; } return Maybe::Ok(); } @@ -98,15 +98,15 @@ namespace oneflow { } /* static */ Maybe IdShuffleOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("num_unique_matrix", 0) = DataType::kUInt32; - *ctx->OutputDType("inverse_unique_partition_indices", 0) = DataType::kUInt32; - *ctx->OutputDType("cur_rank_num_unique", 0) = DataType::kUInt32; - *ctx->OutputDType("cur_rank_unique_ids", 0) = ctx->InputDType("ids", 0); - *ctx->OutputDType("cur_rank_inverse_indices", 0) = DataType::kUInt32; + *ctx->MutOutputDType("num_unique_matrix", 0) = DataType::kUInt32; + *ctx->MutOutputDType("inverse_unique_partition_indices", 0) = DataType::kUInt32; + *ctx->MutOutputDType("cur_rank_num_unique", 0) = DataType::kUInt32; + *ctx->MutOutputDType("cur_rank_unique_ids", 0) = ctx->InputDType("ids", 0); + *ctx->MutOutputDType("cur_rank_inverse_indices", 0) = DataType::kUInt32; if (ctx->has_input("table_ids", 0)) { - *ctx->OutputDType("cur_rank_unique_table_ids", 0) = ctx->InputDType("table_ids", 0); + *ctx->MutOutputDType("cur_rank_unique_table_ids", 0) = ctx->InputDType("table_ids", 0); } else { - *ctx->OutputDType("cur_rank_unique_table_ids", 0) = DataType::kUInt8; + *ctx->MutOutputDType("cur_rank_unique_table_ids", 0) = DataType::kUInt8; } return Maybe::Ok(); } @@ -160,7 +160,7 @@ namespace oneflow { CHECK_OR_RETURN(ctx->InputDType("num_unique_matrix", 0) == DataType::kUInt32); 
CHECK_OR_RETURN(ctx->InputDType("cur_rank_inverse_indices", 0) == DataType::kUInt32); CHECK_OR_RETURN(ctx->InputDType("inverse_unique_partition_indices", 0) == DataType::kUInt32); - *ctx->OutputDType("embeddings", 0) = ctx->InputDType("cur_rank_embeddings", 0); + *ctx->MutOutputDType("embeddings", 0) = ctx->InputDType("cur_rank_embeddings", 0); return Maybe::Ok(); } @@ -201,7 +201,7 @@ namespace oneflow { CHECK_OR_RETURN(ctx->InputDType("num_unique_matrix", 0) == DataType::kUInt32); CHECK_OR_RETURN(ctx->InputDType("cur_rank_inverse_indices", 0) == DataType::kUInt32); CHECK_OR_RETURN(ctx->InputDType("inverse_unique_partition_indices", 0) == DataType::kUInt32); - *ctx->OutputDType("cur_rank_unique_embedding_grad", 0) = ctx->InputDType("embedding_grad", 0); + *ctx->MutOutputDType("cur_rank_unique_embedding_grad", 0) = ctx->InputDType("embedding_grad", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/deconv_op.cpp b/oneflow/user/ops/deconv_op.cpp index 43fd14bb16f..fe943945b2a 100644 --- a/oneflow/user/ops/deconv_op.cpp +++ b/oneflow/user/ops/deconv_op.cpp @@ -85,7 +85,7 @@ Maybe InferTensorDesc4DeConv(user_op::InferContext* ctx) { } Maybe InferDataType_(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/diag_op.cpp b/oneflow/user/ops/diag_op.cpp index 93c9cf1b27e..624b29a07c5 100644 --- a/oneflow/user/ops/diag_op.cpp +++ b/oneflow/user/ops/diag_op.cpp @@ -57,7 +57,7 @@ namespace oneflow { } /* static */ Maybe DiagOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -79,7 +79,7 @@ namespace oneflow { } /* static */ Maybe DiagGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = 
ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/diagonal_op.cpp b/oneflow/user/ops/diagonal_op.cpp index c7bed93b172..2511e6717e5 100644 --- a/oneflow/user/ops/diagonal_op.cpp +++ b/oneflow/user/ops/diagonal_op.cpp @@ -52,7 +52,7 @@ namespace oneflow { } /* static */ Maybe DiagonalOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -74,7 +74,7 @@ namespace oneflow { } /* static */ Maybe DiagonalGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/dim_scatter_ops.cpp b/oneflow/user/ops/dim_scatter_ops.cpp index 60ef6283774..42759138456 100644 --- a/oneflow/user/ops/dim_scatter_ops.cpp +++ b/oneflow/user/ops/dim_scatter_ops.cpp @@ -185,14 +185,14 @@ Maybe InferDtype(user_op::InferContext* ctx) { } else { CHECK_EQ_OR_RETURN(ctx->InputDType("like", 0), ctx->InputDType("src", 0)); } - *ctx->OutputDType("output", 0) = ctx->InputDType("src", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("src", 0); return Maybe::Ok(); } Maybe InferScalarDtype(user_op::InferContext* ctx) { const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); CHECK_OR_RETURN(IsIndexDataType(index.data_type())); - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/normal_op.cpp b/oneflow/user/ops/distributions/normal_op.cpp index 736a70e5d0b..41935f749dd 100644 --- a/oneflow/user/ops/distributions/normal_op.cpp +++ b/oneflow/user/ops/distributions/normal_op.cpp @@ -53,7 +53,7 @@ namespace oneflow { /* static */ Maybe NormalOp::InferDataType(user_op::InferContext* ctx) { auto dtype = ctx->Attr("dtype"); - 
*ctx->OutputDType("out", 0) = dtype; + *ctx->MutOutputDType("out", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/uniform_int_op.cpp b/oneflow/user/ops/distributions/uniform_int_op.cpp index f01bb710f3c..227ec609b4a 100644 --- a/oneflow/user/ops/distributions/uniform_int_op.cpp +++ b/oneflow/user/ops/distributions/uniform_int_op.cpp @@ -56,7 +56,7 @@ namespace oneflow { /* static */ Maybe UniformIntOp::InferDataType(user_op::InferContext* ctx) { auto dtype = ctx->Attr("dtype"); - *ctx->OutputDType("out", 0) = dtype; + *ctx->MutOutputDType("out", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/distributions/uniform_op.cpp b/oneflow/user/ops/distributions/uniform_op.cpp index b7d566aac49..7fea02b8987 100644 --- a/oneflow/user/ops/distributions/uniform_op.cpp +++ b/oneflow/user/ops/distributions/uniform_op.cpp @@ -56,7 +56,7 @@ namespace oneflow { /* static */ Maybe UniformOp::InferDataType(user_op::InferContext* ctx) { auto dtype = ctx->Attr("dtype"); - *ctx->OutputDType("out", 0) = dtype; + *ctx->MutOutputDType("out", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/dot_op.cpp b/oneflow/user/ops/dot_op.cpp index 080a8cff539..27961e60e5b 100644 --- a/oneflow/user/ops/dot_op.cpp +++ b/oneflow/user/ops/dot_op.cpp @@ -52,7 +52,7 @@ namespace oneflow { CHECK_OR_RETURN(x.data_type() == y.data_type()) << Error::RuntimeError() << "expected both vectors to have same dtype, but found " << DataType_Name(x.data_type()) << " and " << DataType_Name(y.data_type()); - *ctx->OutputDType("out", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp index c23d2ef28af..cd5500dc079 100644 --- a/oneflow/user/ops/dropout_op.cpp +++ b/oneflow/user/ops/dropout_op.cpp @@ -22,7 +22,7 @@ namespace oneflow { const Shape& in_shape = ctx->InputShape("in", 0); *ctx->OutputShape("out", 0) = in_shape; 
*ctx->OutputShape("mask", 0) = in_shape; - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -46,15 +46,15 @@ namespace oneflow { } /* static */ Maybe DropoutOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("mask", 0) = DataType::kBool; + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("mask", 0) = DataType::kBool; return Maybe::Ok(); } /* static */ Maybe DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); *ctx->OutputShape("dx", 0) = dy_shape; - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("dy", 0); + *ctx->MutOutputIsDynamic("dx", 0) = ctx->InputIsDynamic("dy", 0); CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape); return Maybe::Ok(); } @@ -83,7 +83,7 @@ namespace oneflow { } /* static */ Maybe DropoutGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool); return Maybe::Ok(); } @@ -117,7 +117,7 @@ namespace oneflow { } /* static */ Maybe RandomMaskLikeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kBool; + *ctx->MutOutputDType("out", 0) = DataType::kBool; return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_b_to_s_op.cpp b/oneflow/user/ops/eager_b_to_s_op.cpp index 00cb6aee242..aa17d53332a 100644 --- a/oneflow/user/ops/eager_b_to_s_op.cpp +++ b/oneflow/user/ops/eager_b_to_s_op.cpp @@ -56,7 +56,7 @@ namespace oneflow { } /* static */ Maybe EagerBToSOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git 
a/oneflow/user/ops/eager_nccl_ops.cpp b/oneflow/user/ops/eager_nccl_ops.cpp index 5f574a7b1be..77ad74a87b6 100644 --- a/oneflow/user/ops/eager_nccl_ops.cpp +++ b/oneflow/user/ops/eager_nccl_ops.cpp @@ -38,7 +38,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclAllReduceOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -64,7 +64,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclBroadcastOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -109,7 +109,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclReduceOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -169,7 +169,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclReduceScatterOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -180,7 +180,7 @@ namespace oneflow { /* static */ Maybe EagerNcclAllGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -216,7 +216,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclAllGatherOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -227,7 +227,7 @@ namespace oneflow { /* static */ Maybe 
EagerNcclS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -261,7 +261,7 @@ namespace oneflow { } /* static */ Maybe EagerNcclS2sOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_p_to_b_op.cpp b/oneflow/user/ops/eager_p_to_b_op.cpp index f503dfcefd9..b3d19e34940 100644 --- a/oneflow/user/ops/eager_p_to_b_op.cpp +++ b/oneflow/user/ops/eager_p_to_b_op.cpp @@ -41,7 +41,7 @@ namespace oneflow { } /* static */ Maybe EagerPToBOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_p_to_s_op.cpp b/oneflow/user/ops/eager_p_to_s_op.cpp index d05bb50df12..440fd6b4229 100644 --- a/oneflow/user/ops/eager_p_to_s_op.cpp +++ b/oneflow/user/ops/eager_p_to_s_op.cpp @@ -55,7 +55,7 @@ namespace oneflow { } /* static */ Maybe EagerPToSOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_b_op.cpp b/oneflow/user/ops/eager_s_to_b_op.cpp index e59d98bb520..9aa68b0d56c 100644 --- a/oneflow/user/ops/eager_s_to_b_op.cpp +++ b/oneflow/user/ops/eager_s_to_b_op.cpp @@ -41,7 +41,7 @@ namespace oneflow { } /* static */ Maybe EagerSToBOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git 
a/oneflow/user/ops/eager_s_to_p_op.cpp b/oneflow/user/ops/eager_s_to_p_op.cpp index 711c8d84501..01c9a7d1137 100644 --- a/oneflow/user/ops/eager_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_s_to_p_op.cpp @@ -41,7 +41,7 @@ namespace oneflow { } /* static */ Maybe EagerSToPOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_s_to_s_op.cpp b/oneflow/user/ops/eager_s_to_s_op.cpp index f2ec6bc933d..cfafa78dcb3 100644 --- a/oneflow/user/ops/eager_s_to_s_op.cpp +++ b/oneflow/user/ops/eager_s_to_s_op.cpp @@ -55,7 +55,7 @@ namespace oneflow { } /* static */ Maybe EagerNaiveSToSOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp index 1767d96e9f4..db6ed285f46 100644 --- a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp @@ -65,7 +65,7 @@ namespace oneflow { } /* static */ Maybe EagerSymmetricSToPOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/elu_op.cpp b/oneflow/user/ops/elu_op.cpp index 9de85d34655..0c142dfe4eb 100644 --- a/oneflow/user/ops/elu_op.cpp +++ b/oneflow/user/ops/elu_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe EluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { /* static */ Maybe EluGradOp::InferDataType(user_op::InferContext* ctx) { 
CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/embedding_op.cpp b/oneflow/user/ops/embedding_op.cpp index 5d124cac674..cf823b09566 100644 --- a/oneflow/user/ops/embedding_op.cpp +++ b/oneflow/user/ops/embedding_op.cpp @@ -33,7 +33,7 @@ namespace oneflow { } /*static*/ Maybe EmbeddingRenormOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -73,7 +73,7 @@ namespace oneflow { } /*static*/ Maybe EmbeddingOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("weight", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("weight", 0); return Maybe::Ok(); } @@ -126,7 +126,7 @@ namespace oneflow { /*static*/ Maybe EmbeddingGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("weight", 0), ctx->InputDType("dy", 0)) << "input grad has different type with weight"; - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/empty_op.cpp b/oneflow/user/ops/empty_op.cpp index 92582ad145d..ad8ac2e0702 100644 --- a/oneflow/user/ops/empty_op.cpp +++ b/oneflow/user/ops/empty_op.cpp @@ -66,7 +66,7 @@ Maybe> MakeEmptyStream(const Symbol& out_device, const bo } /* static */ Maybe EmptyOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + *ctx->MutOutputDType("out", 0) = ctx->Attr("dtype"); return Maybe::Ok(); } diff --git a/oneflow/user/ops/erfinv_op.cpp b/oneflow/user/ops/erfinv_op.cpp index 708e50c89c6..274b9d1ba42 100644 --- a/oneflow/user/ops/erfinv_op.cpp +++ b/oneflow/user/ops/erfinv_op.cpp @@ -38,7 +38,7 @@ namespace 
oneflow { } /* static */ Maybe ErfInvOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/expand_dims_op.cpp b/oneflow/user/ops/expand_dims_op.cpp index f5031f7a1b3..e0468550c40 100644 --- a/oneflow/user/ops/expand_dims_op.cpp +++ b/oneflow/user/ops/expand_dims_op.cpp @@ -65,7 +65,7 @@ int32_t TransformNegativeAxisToPositive(int32_t axis, const int32_t num_axes) { } /* static */ Maybe ExpandDimsOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/expand_op.cpp b/oneflow/user/ops/expand_op.cpp index 9e8cfd5c2ef..135704ec299 100644 --- a/oneflow/user/ops/expand_op.cpp +++ b/oneflow/user/ops/expand_op.cpp @@ -71,7 +71,7 @@ namespace oneflow { } /* static */ Maybe ExpandOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -125,7 +125,7 @@ namespace oneflow { } /* static */ Maybe ExpandGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/eye_op.cpp b/oneflow/user/ops/eye_op.cpp index 077758b2452..9edf9a6aa0b 100644 --- a/oneflow/user/ops/eye_op.cpp +++ b/oneflow/user/ops/eye_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { } /* static */ Maybe EyeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + *ctx->MutOutputDType("out", 0) = ctx->Attr("dtype"); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fake_quantization_op.cpp b/oneflow/user/ops/fake_quantization_op.cpp index fbe6a7d8ca6..8f79529af3e 
100644 --- a/oneflow/user/ops/fake_quantization_op.cpp +++ b/oneflow/user/ops/fake_quantization_op.cpp @@ -104,7 +104,7 @@ namespace oneflow { } /* static */ Maybe FakeQuantizationOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fill_op.cpp b/oneflow/user/ops/fill_op.cpp index 854e9a311e7..37fd7655cb3 100644 --- a/oneflow/user/ops/fill_op.cpp +++ b/oneflow/user/ops/fill_op.cpp @@ -40,7 +40,7 @@ namespace oneflow { } /* static */ Maybe FillOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -75,7 +75,7 @@ namespace oneflow { } /* static */ Maybe FillTensorOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/flatten_op.cpp b/oneflow/user/ops/flatten_op.cpp index 7ac839b479c..c7798d56fb5 100644 --- a/oneflow/user/ops/flatten_op.cpp +++ b/oneflow/user/ops/flatten_op.cpp @@ -82,7 +82,7 @@ namespace oneflow { } /* static */ Maybe FlattenOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/flip_op.cpp b/oneflow/user/ops/flip_op.cpp index b7d750552a9..24af3a3f4b7 100644 --- a/oneflow/user/ops/flip_op.cpp +++ b/oneflow/user/ops/flip_op.cpp @@ -49,7 +49,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ auto FlipOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git 
a/oneflow/user/ops/fused_bias_add_op.cpp b/oneflow/user/ops/fused_bias_add_op.cpp index 46f9394ff18..ecad1ee62c3 100644 --- a/oneflow/user/ops/fused_bias_add_op.cpp +++ b/oneflow/user/ops/fused_bias_add_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) @@ -37,7 +37,7 @@ namespace oneflow { } /*static*/ auto FusedBiasAddGeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); + *ctx->MutOutputDType("out", 0) = a_tensor_desc.data_type(); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { @@ -68,7 +68,7 @@ namespace oneflow { CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); *ctx->OutputShape("dx", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } @@ -78,7 +78,7 @@ namespace oneflow { } /*static*/ auto FusedBiasAddGeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("dx", 0) = a_tensor_desc.data_type(); + *ctx->MutOutputDType("dx", 0) = a_tensor_desc.data_type(); return Maybe::Ok(); } /*static*/ auto FusedBiasAddGeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { @@ -153,7 +153,7 @@ REGISTER_USER_OP_GRAD("fused_bias_add_gelu") 
CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); return Maybe::Ok(); } /*static*/ auto FusedBiasAddMaskScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) @@ -162,7 +162,7 @@ REGISTER_USER_OP_GRAD("fused_bias_add_gelu") } /*static*/ auto FusedBiasAddMaskScaleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); + *ctx->MutOutputDType("out", 0) = a_tensor_desc.data_type(); return Maybe::Ok(); } /*static*/ auto FusedBiasAddMaskScaleOp::ModifyInputArg( diff --git a/oneflow/user/ops/fused_cross_feature_interaction_op.cpp b/oneflow/user/ops/fused_cross_feature_interaction_op.cpp index 5486fc9634a..c7b317cef9d 100644 --- a/oneflow/user/ops/fused_cross_feature_interaction_op.cpp +++ b/oneflow/user/ops/fused_cross_feature_interaction_op.cpp @@ -50,8 +50,8 @@ namespace oneflow { } /* static */ Maybe FusedCrossFeatureInteractionOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("matmul_result", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("matmul_result", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -89,10 +89,10 @@ namespace oneflow { /* static */ Maybe FusedCrossFeatureInteractionV1GradOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("dx0", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("dw", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("dbias", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx0", 0) = 
ctx->InputDType("x", 0); + *ctx->MutOutputDType("dw", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dbias", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -131,10 +131,10 @@ namespace oneflow { /* static */ Maybe FusedCrossFeatureInteractionV2GradOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("dx0", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("dw", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("dbias", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx0", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dw", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dbias", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp index 0d99cf8b489..23274a68881 100644 --- a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp +++ b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp @@ -98,7 +98,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType("sparse_feature", 0)) << "get " << first_feature_dtype << " and " << ctx->InputDType("sparse_feature", 0); } - *ctx->OutputDType("out", 0) = first_feature_dtype; + *ctx->MutOutputDType("out", 0) = first_feature_dtype; return Maybe::Ok(); } @@ -139,13 +139,13 @@ namespace oneflow { user_op::InferContext* ctx) { const auto& dy_dtype = ctx->InputDType("dy", 0); for (int64_t i = 0; i < ctx->output_size("features_grad"); ++i) { - *ctx->OutputDType("features_grad", i) = dy_dtype; + *ctx->MutOutputDType("features_grad", i) = dy_dtype; } if (ctx->has_output("output_concat_grad", 0)) { - *ctx->OutputDType("output_concat_grad", 0) = dy_dtype; + *ctx->MutOutputDType("output_concat_grad", 0) = dy_dtype; } if (ctx->has_output("sparse_feature_grad", 0)) { - 
*ctx->OutputDType("sparse_feature_grad", 0) = dy_dtype; + *ctx->MutOutputDType("sparse_feature_grad", 0) = dy_dtype; } return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_gru_cell_op.cpp b/oneflow/user/ops/fused_gru_cell_op.cpp index b9b6b7063f1..aefc4ca05e0 100644 --- a/oneflow/user/ops/fused_gru_cell_op.cpp +++ b/oneflow/user/ops/fused_gru_cell_op.cpp @@ -61,8 +61,8 @@ namespace oneflow { /* static */ Maybe FusedGruCellOp::InferDataType(user_op::InferContext* ctx) { const oneflow::DataType& in_types = ctx->InputDType("hx", 0); - *ctx->OutputDType("hy", 0) = in_types; - *ctx->OutputDType("workspace", 0) = in_types; + *ctx->MutOutputDType("hy", 0) = in_types; + *ctx->MutOutputDType("workspace", 0) = in_types; return Maybe::Ok(); } @@ -118,12 +118,14 @@ namespace oneflow { /* static */ Maybe FusedGruCellGradOp ::InferDataType(user_op::InferContext* ctx) { const oneflow::DataType& in_types = ctx->InputDType("grad_hy", 0); - *ctx->OutputDType("grad_input_gates", 0) = in_types; - *ctx->OutputDType("grad_hidden_gates", 0) = in_types; - if (ctx->has_output("grad_hx", 0)) { *ctx->OutputDType("grad_hx", 0) = in_types; } - if (ctx->has_output("grad_input_bias", 0)) { *ctx->OutputDType("grad_input_bias", 0) = in_types; } + *ctx->MutOutputDType("grad_input_gates", 0) = in_types; + *ctx->MutOutputDType("grad_hidden_gates", 0) = in_types; + if (ctx->has_output("grad_hx", 0)) { *ctx->MutOutputDType("grad_hx", 0) = in_types; } + if (ctx->has_output("grad_input_bias", 0)) { + *ctx->MutOutputDType("grad_input_bias", 0) = in_types; + } if (ctx->has_output("grad_hidden_bias", 0)) { - *ctx->OutputDType("grad_hidden_bias", 0) = in_types; + *ctx->MutOutputDType("grad_hidden_bias", 0) = in_types; } return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_lstm_cell_op.cpp b/oneflow/user/ops/fused_lstm_cell_op.cpp index 5ce8add4f7b..3155dbcb88d 100644 --- a/oneflow/user/ops/fused_lstm_cell_op.cpp +++ b/oneflow/user/ops/fused_lstm_cell_op.cpp @@ -64,9 +64,9 @@ namespace oneflow { /* 
static */ Maybe FusedLstmCellOp::InferDataType(user_op::InferContext* ctx) { const oneflow::DataType& in_types = ctx->InputDType("cx", 0); - *ctx->OutputDType("hy", 0) = in_types; - *ctx->OutputDType("cy", 0) = in_types; - *ctx->OutputDType("workspace", 0) = in_types; + *ctx->MutOutputDType("hy", 0) = in_types; + *ctx->MutOutputDType("cy", 0) = in_types; + *ctx->MutOutputDType("workspace", 0) = in_types; return Maybe::Ok(); } @@ -118,9 +118,9 @@ namespace oneflow { /* static */ Maybe FusedLstmCellGradOp::InferDataType(user_op::InferContext* ctx) { const oneflow::DataType& in_types = ctx->InputDType("grad_hy", 0); - *ctx->OutputDType("grad_gates", 0) = in_types; - if (ctx->has_output("grad_cx", 0)) { *ctx->OutputDType("grad_cx", 0) = in_types; } - if (ctx->has_output("grad_bias", 0)) { *ctx->OutputDType("grad_bias", 0) = in_types; } + *ctx->MutOutputDType("grad_gates", 0) = in_types; + if (ctx->has_output("grad_cx", 0)) { *ctx->MutOutputDType("grad_cx", 0) = in_types; } + if (ctx->has_output("grad_bias", 0)) { *ctx->MutOutputDType("grad_bias", 0) = in_types; } return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_relu_dropout_grad_op.cpp b/oneflow/user/ops/fused_relu_dropout_grad_op.cpp index 14101dd16c5..3e99e0ee82e 100644 --- a/oneflow/user/ops/fused_relu_dropout_grad_op.cpp +++ b/oneflow/user/ops/fused_relu_dropout_grad_op.cpp @@ -30,7 +30,7 @@ Maybe InferTensorDesc4FusedReluDropoutGrad(user_op::InferContext* ctx) { } Maybe InferDataType4FusedReluDropoutGrad(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp index eabeed57b06..d9e6e5df948 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp @@ -28,9 +28,9 @@ namespace oneflow { 
mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("y", 0) = x_desc.is_dynamic(); *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) @@ -42,8 +42,8 @@ namespace oneflow { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << " mask dtype only support bool."; - *ctx->OutputDType("y", 0) = x_desc.data_type(); - *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); + *ctx->MutOutputDType("y", 0) = x_desc.data_type(); + *ctx->MutOutputDType("softmax_y", 0) = x_desc.data_type(); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxDropoutOp::ModifyInputArg( diff --git a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp index 235e897db47..4add92dce90 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { mask_desc.shape().At(mask_shape.NumAxes() - 1)) << " last dim of x and mask is not equal."; *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) @@ -39,7 +39,7 @@ namespace oneflow { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); 
CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << " mask dtype only support bool."; - *ctx->OutputDType("y", 0) = x_desc.data_type(); + *ctx->MutOutputDType("y", 0) = x_desc.data_type(); return Maybe::Ok(); } /*static*/ auto FusedScaleMaskSoftmaxOp::ModifyInputArg( diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index 20dead6c8d7..c84b35f9712 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -21,9 +21,9 @@ namespace oneflow { -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("y", 0) = x_desc.is_dynamic(); *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferPhysicalTensorDesc( @@ -33,8 +33,8 @@ namespace oneflow { /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - *ctx->OutputDType("y", 0) = x_desc.data_type(); - *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); + *ctx->MutOutputDType("y", 0) = x_desc.data_type(); + *ctx->MutOutputDType("softmax_y", 0) = x_desc.data_type(); return Maybe::Ok(); } /*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::ModifyInputArg( diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp index 232a78189c9..e19be385293 100644 --- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp +++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp @@ 
-21,8 +21,8 @@ namespace oneflow { /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const DataType& dtype = ctx->InputDType("hidden_states", 0); - *ctx->OutputDType("query_mul_key", 0) = dtype; - *ctx->OutputDType("value", 0) = dtype; + *ctx->MutOutputDType("query_mul_key", 0) = dtype; + *ctx->MutOutputDType("value", 0) = dtype; return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc( @@ -69,7 +69,7 @@ namespace oneflow { user_op::InferContext* ctx) -> Maybe { const DataType& dtype = ctx->InputDType("query_mul_key_grad", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("value_grad", 0), dtype); - *ctx->OutputDType("hidden_states_grad", 0) = dtype; + *ctx->MutOutputDType("hidden_states_grad", 0) = dtype; return Maybe::Ok(); } /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc( diff --git a/oneflow/user/ops/gelu_op.cpp b/oneflow/user/ops/gelu_op.cpp index 39f12592c23..deed561c010 100644 --- a/oneflow/user/ops/gelu_op.cpp +++ b/oneflow/user/ops/gelu_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ auto GeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -68,7 +68,7 @@ namespace oneflow { } /*static*/ auto GeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp index 73b7dcb52eb..1113b008df4 100644 --- a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp +++ 
b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp @@ -39,7 +39,7 @@ namespace oneflow { } /*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = DataType::kInt32; + *ctx->MutOutputDType("y", 0) = DataType::kInt32; return Maybe::Ok(); } diff --git a/oneflow/user/ops/grid_sample_op.cpp b/oneflow/user/ops/grid_sample_op.cpp index febe10c65ad..45ff68c888d 100644 --- a/oneflow/user/ops/grid_sample_op.cpp +++ b/oneflow/user/ops/grid_sample_op.cpp @@ -100,7 +100,7 @@ Maybe GridSampleOp::CheckAttr(const user_op::UserOpDefWrapper& def, return Maybe::Ok(); } /*static*/ auto GridSampleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } @@ -137,8 +137,8 @@ Maybe GridSampleGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, return Maybe::Ok(); } /*static*/ auto GridSampleGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dinput", 0) = ctx->InputDType("input", 0); - *ctx->OutputDType("dgrid", 0) = ctx->InputDType("grid", 0); + *ctx->MutOutputDType("dinput", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("dgrid", 0) = ctx->InputDType("grid", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/hardshrink_op.cpp b/oneflow/user/ops/hardshrink_op.cpp index 21fdae26a17..78753a1a885 100644 --- a/oneflow/user/ops/hardshrink_op.cpp +++ b/oneflow/user/ops/hardshrink_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe HardShrinkOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -68,7 +68,7 @@ namespace oneflow { /* static */ Maybe HardShrinkGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 
0), ctx->InputDType("y", 0)) << "The dtype of y_grad and y must be same."; - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/hardsigmoid_op.cpp b/oneflow/user/ops/hardsigmoid_op.cpp index 887614425ac..502f582c97a 100644 --- a/oneflow/user/ops/hardsigmoid_op.cpp +++ b/oneflow/user/ops/hardsigmoid_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { } /* static */ Maybe HardsigmoidOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -69,7 +69,7 @@ namespace oneflow { /* static */ Maybe HardsigmoidGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/hardswish_op.cpp b/oneflow/user/ops/hardswish_op.cpp index f7dfbc5c870..05c07c61282 100644 --- a/oneflow/user/ops/hardswish_op.cpp +++ b/oneflow/user/ops/hardswish_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe HardswishOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { /* static */ Maybe HardswishGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/hardtanh_op.cpp b/oneflow/user/ops/hardtanh_op.cpp index 2d5208c7b0b..94c5c65b021 100644 --- a/oneflow/user/ops/hardtanh_op.cpp +++ b/oneflow/user/ops/hardtanh_op.cpp @@ -41,7 
+41,7 @@ namespace oneflow { } /* static */ Maybe HardtanhOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -75,7 +75,7 @@ namespace oneflow { /* static */ Maybe HardtanhGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp index 7ddad5a603f..d8b58606e9d 100644 --- a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp +++ b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp @@ -22,7 +22,7 @@ namespace oneflow { /* static */ Maybe HierarchicalParallelCastOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -51,14 +51,14 @@ namespace oneflow { } /* static */ Maybe HierarchicalParallelCastOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /* static */ Maybe HierarchicalParallelCastLikeOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -84,7 +84,7 @@ namespace oneflow { } /* static */ Maybe HierarchicalParallelCastLikeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + 
*ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/identity_op.cpp b/oneflow/user/ops/identity_op.cpp index 538abeb5dde..15471774755 100644 --- a/oneflow/user/ops/identity_op.cpp +++ b/oneflow/user/ops/identity_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe IdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -41,7 +41,7 @@ namespace oneflow { } /* static */ Maybe IdentityOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_object_preprocess_ops.cpp b/oneflow/user/ops/image_object_preprocess_ops.cpp index 5fd2cb99f38..76da9b10495 100644 --- a/oneflow/user/ops/image_object_preprocess_ops.cpp +++ b/oneflow/user/ops/image_object_preprocess_ops.cpp @@ -36,7 +36,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -51,7 +51,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { /* static */ Maybe ImageFlipOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -67,7 +67,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { 
CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); return Maybe::Ok(); } @@ -86,7 +86,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); - *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("bbox", 0); return Maybe::Ok(); } @@ -99,7 +99,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); return Maybe::Ok(); } @@ -116,7 +116,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); - *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("bbox", 0); return Maybe::Ok(); } @@ -133,7 +133,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } @@ -154,7 +154,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); const user_op::TensorDesc& 
flip_code_desc = ctx->InputTensorDesc("flip_code", 0); CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("poly", 0); return Maybe::Ok(); } @@ -168,7 +168,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } @@ -187,7 +187,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("poly", 0); return Maybe::Ok(); } @@ -195,7 +195,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -210,7 +210,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { /* static */ Maybe ImageNormalizeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -228,7 +228,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 
2); *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); return Maybe::Ok(); } @@ -249,7 +249,7 @@ Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(poly_index_desc.data_type(), DataType::kTensorBuffer); const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("poly", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index 00c6d419c8b..e135e0f25db 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -234,7 +234,7 @@ namespace oneflow { /* static */ Maybe ImageRandomCropOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = in_tensor.data_type(); + *ctx->MutOutputDType("out", 0) = in_tensor.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/kl_div_op.cpp b/oneflow/user/ops/kl_div_op.cpp index cb58f29764b..636e1680015 100644 --- a/oneflow/user/ops/kl_div_op.cpp +++ b/oneflow/user/ops/kl_div_op.cpp @@ -38,7 +38,7 @@ Maybe KlInferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); - *ctx->OutputDType("out", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } @@ -63,7 +63,7 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); 
CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp index 05affa22404..1909550ead4 100644 --- a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp +++ b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp @@ -25,7 +25,7 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape()); *ctx->OutputShape("out", 0) = ctx->InputShape("model", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("model", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("model", 0); return Maybe::Ok(); } @@ -57,7 +57,7 @@ Maybe GetSbpSignatures(user_op::SbpContext* ctx) { const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); CHECK_EQ_OR_RETURN(model_diff.data_type(), model.data_type()); - *ctx->OutputDType("out", 0) = ctx->InputDType("model", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("model", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/l2_normalize_op.cpp b/oneflow/user/ops/l2_normalize_op.cpp index d1723c41c97..59d3cb62100 100644 --- a/oneflow/user/ops/l2_normalize_op.cpp +++ b/oneflow/user/ops/l2_normalize_op.cpp @@ -53,8 +53,8 @@ namespace oneflow { } /* static */ Maybe L2NormalizeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("square_x_sum", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("square_x_sum", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -103,7 +103,7 @@ namespace oneflow { /* static */ Maybe 
L2NormalizeGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("square_x_sum", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/leaky_relu_op.cpp b/oneflow/user/ops/leaky_relu_op.cpp index 09d8b318c54..b1a26664201 100644 --- a/oneflow/user/ops/leaky_relu_op.cpp +++ b/oneflow/user/ops/leaky_relu_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { } /* static */ Maybe LeakyReluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -74,7 +74,7 @@ namespace oneflow { /* static */ Maybe LeakyReluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/log_softmax_op.cpp b/oneflow/user/ops/log_softmax_op.cpp index d8cffbf7460..55cba6fff21 100644 --- a/oneflow/user/ops/log_softmax_op.cpp +++ b/oneflow/user/ops/log_softmax_op.cpp @@ -39,7 +39,7 @@ namespace oneflow { } /* static */ Maybe LogSoftmaxOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("prob", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("prob", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -70,7 +70,7 @@ namespace oneflow { /* static */ Maybe LogSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("prob", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("prob", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/logical_not_op.cpp 
b/oneflow/user/ops/logical_not_op.cpp index 730b13a9eab..d528ca2d09c 100644 --- a/oneflow/user/ops/logical_not_op.cpp +++ b/oneflow/user/ops/logical_not_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { namespace { Maybe InferDataTypeLogicalNot(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = DataType::kBool; + *ctx->MutOutputDType("y", 0) = DataType::kBool; return Maybe::Ok(); } diff --git a/oneflow/user/ops/masked_fill_op.cpp b/oneflow/user/ops/masked_fill_op.cpp index 327ce994ded..ed1bb688d0d 100644 --- a/oneflow/user/ops/masked_fill_op.cpp +++ b/oneflow/user/ops/masked_fill_op.cpp @@ -29,7 +29,7 @@ Maybe InferMaskedFillTensorDesc(user_op::InferContext* ctx) { Maybe InferMaskedFillDataType(user_op::InferContext* ctx) { const DataType& mask_dtype = ctx->InputDType("mask", 0); CHECK_OR_RETURN(IsIntegralDataType(mask_dtype) || IsBoolDataType(mask_dtype)); - *ctx->OutputDType("out", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp index 0c4ef770ac3..cde0c5a4fce 100644 --- a/oneflow/user/ops/math_binary_broadcast_ops.cpp +++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp @@ -36,21 +36,21 @@ Maybe InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) { size_t output_num_axes = std::max(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes()); if (IsZeroDimTensor(&tensor_x)) { *ctx->OutputShape("z", 0) = ctx->InputShape("y", 0); - *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); + *ctx->MutOutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); } else if (IsZeroDimTensor(&tensor_y)) { *ctx->OutputShape("z", 0) = ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); + *ctx->MutOutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); } else if (IsScalarTensor(&tensor_x)) { *ctx->OutputShape("z", 0) = ctx->InputShape("y", 0); - 
*ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); + *ctx->MutOutputIsDynamic("z", 0) = ctx->InputIsDynamic("y", 0); } else if (IsScalarTensor(&tensor_y)) { *ctx->OutputShape("z", 0) = ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); + *ctx->MutOutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); } else { const auto& x_shape = CreateLeftExtendedShape(ShapeView(tensor_x.shape()), output_num_axes); const auto& y_shape = CreateLeftExtendedShape(ShapeView(tensor_y.shape()), output_num_axes); *ctx->OutputShape("z", 0) = ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); + *ctx->MutOutputIsDynamic("z", 0) = ctx->InputIsDynamic("x", 0); Shape out_shape(x_shape); FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { if (x_shape.At(i) != 1 && y_shape.At(i) != 1 && x_shape.At(i) != y_shape.At(i)) { @@ -76,7 +76,7 @@ Maybe InferDataTypeBinaryBroadcastNormal(user_op::InferContext* ctx) { const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(tensor_x.data_type(), tensor_y.data_type()); // NOLINT(maybe-need-error-msg) - *ctx->OutputDType("z", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("z", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -84,7 +84,7 @@ Maybe InferDataTypeBinaryBroadcastLogical(user_op::InferContext* ctx) { const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(tensor_x.data_type(), tensor_y.data_type()); // NOLINT(maybe-need-error-msg) - *ctx->OutputDType("z", 0) = DataType::kBool; + *ctx->MutOutputDType("z", 0) = DataType::kBool; return Maybe::Ok(); } diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index 9604177ed77..b199a30b081 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -37,7 +37,7 @@ Maybe 
InferTensorDesc4Matmul(user_op::InferContext* ctx) { user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); int64_t m, n, k; // tensor a (no trans): m*k, tensor b (no trans): k*n if (!transpose_a) { @@ -69,7 +69,7 @@ Maybe InferDataType4Matmul(user_op::InferContext* ctx) { if (ctx->has_input("_add_to_output", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("_add_to_output", 0), dtype); } - *ctx->OutputDType("out", 0) = dtype; + *ctx->MutOutputDType("out", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/matrix_vector_product_op.cpp b/oneflow/user/ops/matrix_vector_product_op.cpp index 91cfba1224b..2814bdd2d13 100644 --- a/oneflow/user/ops/matrix_vector_product_op.cpp +++ b/oneflow/user/ops/matrix_vector_product_op.cpp @@ -34,7 +34,7 @@ Maybe InferDataType4MatrixVectorProduct(user_op::InferContext* ctx) { const DataType& dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype) << "Matrix A datatype should be equal to Vector B. 
"; - *ctx->OutputDType("out", 0) = dtype; + *ctx->MutOutputDType("out", 0) = dtype; return Maybe::Ok(); } @@ -64,7 +64,7 @@ Maybe InferTensorDesc4MatrixVectorProductGradB(user_op::InferContext* ctx) Maybe InferDataType4Grad(user_op::InferContext* ctx) { const DataType& dtype = ctx->InputDType("dy", 0); - *ctx->OutputDType("dx", 0) = dtype; + *ctx->MutOutputDType("dx", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/max_pool_op.cpp b/oneflow/user/ops/max_pool_op.cpp index 8d4d20bc797..53c5573f2a6 100644 --- a/oneflow/user/ops/max_pool_op.cpp +++ b/oneflow/user/ops/max_pool_op.cpp @@ -116,12 +116,12 @@ Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { } Maybe FwInferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } Maybe BwInferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } } // namespace diff --git a/oneflow/user/ops/median_op.cpp b/oneflow/user/ops/median_op.cpp index 5ca4689b037..0796c883c6c 100644 --- a/oneflow/user/ops/median_op.cpp +++ b/oneflow/user/ops/median_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe MedianOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/median_with_indices_op.cpp b/oneflow/user/ops/median_with_indices_op.cpp index d9d0d672735..9f1d0ca5af2 100644 --- a/oneflow/user/ops/median_with_indices_op.cpp +++ b/oneflow/user/ops/median_with_indices_op.cpp @@ -42,8 +42,8 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe MedianWithIndicesOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("values", 0) 
= ctx->InputDType("input", 0); - *ctx->OutputDType("indices", 0) = DataType::kInt64; + *ctx->MutOutputDType("values", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("indices", 0) = DataType::kInt64; return Maybe::Ok(); } diff --git a/oneflow/user/ops/min_max_observer_op.cpp b/oneflow/user/ops/min_max_observer_op.cpp index 3d7f186c378..b5f83abe8d2 100644 --- a/oneflow/user/ops/min_max_observer_op.cpp +++ b/oneflow/user/ops/min_max_observer_op.cpp @@ -70,8 +70,8 @@ namespace oneflow { } /* static */ Maybe MinMaxObserverOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("scale", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("zero_point", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/mish_op.cpp b/oneflow/user/ops/mish_op.cpp index bee4ebb18a8..88902139a49 100644 --- a/oneflow/user/ops/mish_op.cpp +++ b/oneflow/user/ops/mish_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe MishOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { /* static */ Maybe MishGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index 0bcaf045247..eba343916bf 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -766,7 +766,7 @@ Maybe InferLarsUpdateDataType(user_op::InferContext* ctx) { } /* static */ Maybe AdamBiasCorrectionFactorOp::InferDataType(user_op::InferContext* 
ctx) { - *ctx->OutputDType("out", 0) = DataType::kFloat; + *ctx->MutOutputDType("out", 0) = DataType::kFloat; return Maybe::Ok(); } diff --git a/oneflow/user/ops/moving_average_min_max_observer_op.cpp b/oneflow/user/ops/moving_average_min_max_observer_op.cpp index 434865f2d59..0a44d018c55 100644 --- a/oneflow/user/ops/moving_average_min_max_observer_op.cpp +++ b/oneflow/user/ops/moving_average_min_max_observer_op.cpp @@ -87,8 +87,8 @@ namespace oneflow { } /* static */ Maybe MovingAverageMinMaxObserverOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("scale", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("zero_point", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/multi_reduce_ops.cpp b/oneflow/user/ops/multi_reduce_ops.cpp index 58ceca4ff10..8ad90716e8a 100644 --- a/oneflow/user/ops/multi_reduce_ops.cpp +++ b/oneflow/user/ops/multi_reduce_ops.cpp @@ -33,7 +33,7 @@ Maybe InferMultiReduceOpDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", i), x_0_dtype) << ctx->op_name() << ": the " << i << " th input has the different data type with others"; } - *ctx->OutputDType("y", 0) = x_0_dtype; + *ctx->MutOutputDType("y", 0) = x_0_dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/narrow_op.cpp b/oneflow/user/ops/narrow_op.cpp index a8569c6784e..6c4814b590d 100644 --- a/oneflow/user/ops/narrow_op.cpp +++ b/oneflow/user/ops/narrow_op.cpp @@ -131,7 +131,7 @@ namespace oneflow { } /* static */ Maybe NarrowGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp index f8bf37f2771..f6d53877961 100644 --- 
a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp +++ b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp @@ -24,7 +24,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -54,7 +54,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -66,7 +66,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -96,7 +96,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -108,7 +108,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -139,7 +139,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); 
return Maybe::Ok(); } @@ -151,7 +151,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -183,7 +183,7 @@ namespace oneflow { /* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -196,7 +196,7 @@ _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDeviceAndStream( /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -226,7 +226,7 @@ _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDeviceAndStream( /* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nccl_logical_ops.cpp b/oneflow/user/ops/nccl_logical_ops.cpp index 5f157516389..7cda3fe908e 100644 --- a/oneflow/user/ops/nccl_logical_ops.cpp +++ b/oneflow/user/ops/nccl_logical_ops.cpp @@ -24,7 +24,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllReduceOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } 
@@ -51,7 +51,7 @@ namespace oneflow { } /* static */ Maybe _ncclLogicalAllReduceOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -63,7 +63,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalReduceScatterOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -92,7 +92,7 @@ namespace oneflow { } /* static */ Maybe _ncclLogicalReduceScatterOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -104,7 +104,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllGatherOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -132,7 +132,7 @@ namespace oneflow { } /* static */ Maybe _ncclLogicalAllGatherOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -144,7 +144,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -174,7 +174,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferDataType( 
user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -186,7 +186,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -220,7 +220,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalReduceScatterNoncontinuousOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -231,7 +231,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -258,7 +258,7 @@ namespace oneflow { } /* static */ Maybe _ncclLogicalS2sOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -270,7 +270,7 @@ namespace oneflow { /* static */ Maybe _ncclLogicalSendRecvOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -291,7 +291,7 @@ namespace oneflow { } /* static */ Maybe _ncclLogicalSendRecvOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + 
*ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nd_index_slice_ops.cpp b/oneflow/user/ops/nd_index_slice_ops.cpp index 2fa17d2d390..792331acfce 100644 --- a/oneflow/user/ops/nd_index_slice_ops.cpp +++ b/oneflow/user/ops/nd_index_slice_ops.cpp @@ -47,7 +47,7 @@ Maybe InferScatterNdTensorDesc(user_op::InferContext* ctx) { } Maybe InferScatterNdDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("updates", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("updates", 0); return Maybe::Ok(); } @@ -61,7 +61,7 @@ Maybe InferScatterNdLikeTensorDesc(user_op::InferContext* ctx) { } Maybe InferScatterNdLikeDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("updates", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("updates", 0); return Maybe::Ok(); } @@ -75,7 +75,7 @@ Maybe InferTensorScatterNdOptTensorDesc(user_op::InferContext* ctx) { } Maybe InferTensorScatterNdOptDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("params", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("params", 0); return Maybe::Ok(); } @@ -168,7 +168,7 @@ Maybe GetTensorScatterNdOptSbpSignatures(user_op::SbpContext* ctx) { } /* static */ Maybe GatherNdOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("params", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("params", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nll_op.cpp b/oneflow/user/ops/nll_op.cpp index 1afffc2c16b..b7a56773e9f 100644 --- a/oneflow/user/ops/nll_op.cpp +++ b/oneflow/user/ops/nll_op.cpp @@ -29,8 +29,8 @@ namespace oneflow { << input_dtype << ", but got " << weight_dtype; } - *ctx->OutputDType("output", 0) = input_dtype; - *ctx->OutputDType("out_weight", 0) = input_dtype; + *ctx->MutOutputDType("output", 0) = input_dtype; + *ctx->MutOutputDType("out_weight", 0) = input_dtype; return 
Maybe::Ok(); } @@ -126,7 +126,7 @@ namespace oneflow { << ctx->InputDType("weight", 0); } - *ctx->OutputDType("in_grad", 0) = input_dtype; + *ctx->MutOutputDType("in_grad", 0) = input_dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/nms_op.cpp b/oneflow/user/ops/nms_op.cpp index 1d9c0e29537..b52377ed003 100644 --- a/oneflow/user/ops/nms_op.cpp +++ b/oneflow/user/ops/nms_op.cpp @@ -26,7 +26,7 @@ Maybe InferNmsTensorDesc(user_op::InferContext* ctx) { } Maybe InferNmsDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kInt8; + *ctx->MutOutputDType("out", 0) = DataType::kInt8; return Maybe::Ok(); } diff --git a/oneflow/user/ops/nvtx_range_op.cpp b/oneflow/user/ops/nvtx_range_op.cpp index 0f2bd54b2e6..1a50cd9cff6 100644 --- a/oneflow/user/ops/nvtx_range_op.cpp +++ b/oneflow/user/ops/nvtx_range_op.cpp @@ -23,7 +23,7 @@ namespace oneflow { /* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -44,13 +44,13 @@ namespace oneflow { } /* static */ Maybe NvtxStartOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } @@ -71,7 +71,7 @@ namespace oneflow { } /* static */ Maybe NvtxEndOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff 
--git a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp index 800ec2d27e0..9b683de5a1f 100644 --- a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp +++ b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp @@ -68,8 +68,8 @@ namespace oneflow { /* static */ Maybe OfrecordImageClassificationReaderOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("image", 0) = DataType::kTensorBuffer; - *ctx->OutputDType("label", 0) = DataType::kTensorBuffer; + *ctx->MutOutputDType("image", 0) = DataType::kTensorBuffer; + *ctx->MutOutputDType("label", 0) = DataType::kTensorBuffer; return Maybe::Ok(); } diff --git a/oneflow/user/ops/ofrecord_reader_op.cpp b/oneflow/user/ops/ofrecord_reader_op.cpp index 6d40f5f92bd..a43a08015a7 100644 --- a/oneflow/user/ops/ofrecord_reader_op.cpp +++ b/oneflow/user/ops/ofrecord_reader_op.cpp @@ -64,7 +64,7 @@ namespace oneflow { } /* static */ Maybe OFRecordReaderOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kOFRecord; + *ctx->MutOutputDType("out", 0) = DataType::kOFRecord; return Maybe::Ok(); } diff --git a/oneflow/user/ops/one_embedding_ops.cpp b/oneflow/user/ops/one_embedding_ops.cpp index 49a3f0b8fe0..4777c9d4e91 100644 --- a/oneflow/user/ops/one_embedding_ops.cpp +++ b/oneflow/user/ops/one_embedding_ops.cpp @@ -68,7 +68,7 @@ namespace oneflow { } /* static */ Maybe EmbeddingLookupPlaceholderOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("embeddings", 0) = ctx->InputDType("shadow", 0); + *ctx->MutOutputDType("embeddings", 0) = ctx->InputDType("shadow", 0); return Maybe::Ok(); } @@ -135,7 +135,7 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") } /* static */ Maybe EmbeddingPrefetchOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("context", 0) = ctx->InputDType("num_unique_ids", 0); + *ctx->MutOutputDType("context", 0) = 
ctx->InputDType("num_unique_ids", 0); return Maybe::Ok(); } @@ -203,9 +203,9 @@ REGISTER_USER_OP_GRAD("embedding_lookup_placeholder") } /* static */ Maybe EmbeddingLookupOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("unique_values", 0) = ctx->Attr("dtype"); + *ctx->MutOutputDType("unique_values", 0) = ctx->Attr("dtype"); if (ctx->has_output("embeddings", 0)) { - *ctx->OutputDType("embeddings", 0) = ctx->Attr("embeddings_dtype"); + *ctx->MutOutputDType("embeddings", 0) = ctx->Attr("embeddings_dtype"); } return Maybe::Ok(); } @@ -333,7 +333,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { /* static */ Maybe SgdEmbeddingUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); - *ctx->OutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); + *ctx->MutOutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); return Maybe::Ok(); } @@ -362,7 +362,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { /* static */ Maybe MomentumEmbeddingUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); - *ctx->OutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); + *ctx->MutOutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); return Maybe::Ok(); } @@ -389,7 +389,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { /* static */ Maybe AdamEmbeddingUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); - *ctx->OutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); + *ctx->MutOutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); return Maybe::Ok(); } @@ -418,7 +418,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { /* static */ Maybe AdagradEmbeddingUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); - 
*ctx->OutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); + *ctx->MutOutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); return Maybe::Ok(); } @@ -445,7 +445,7 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { /* static */ Maybe FtrlEmbeddingUpdateOp::InferDataType(user_op::InferContext* ctx) { JUST(CheckDataType(ctx)); - *ctx->OutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); + *ctx->MutOutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); return Maybe::Ok(); } @@ -477,14 +477,14 @@ Maybe GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe IdShuffleCopyOutOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out_num_unique_matrix", 0) = ctx->InputDType("num_unique_matrix", 0); - *ctx->OutputDType("out_inverse_unique_partition_indices", 0) = + *ctx->MutOutputDType("out_num_unique_matrix", 0) = ctx->InputDType("num_unique_matrix", 0); + *ctx->MutOutputDType("out_inverse_unique_partition_indices", 0) = ctx->InputDType("inverse_unique_partition_indices", 0); - *ctx->OutputDType("out_cur_rank_num_unique", 0) = ctx->InputDType("cur_rank_num_unique", 0); - *ctx->OutputDType("out_cur_rank_unique_ids", 0) = ctx->InputDType("cur_rank_unique_ids", 0); - *ctx->OutputDType("out_cur_rank_unique_table_ids", 0) = + *ctx->MutOutputDType("out_cur_rank_num_unique", 0) = ctx->InputDType("cur_rank_num_unique", 0); + *ctx->MutOutputDType("out_cur_rank_unique_ids", 0) = ctx->InputDType("cur_rank_unique_ids", 0); + *ctx->MutOutputDType("out_cur_rank_unique_table_ids", 0) = ctx->InputDType("cur_rank_unique_table_ids", 0); - *ctx->OutputDType("out_cur_rank_inverse_indices", 0) = + *ctx->MutOutputDType("out_cur_rank_inverse_indices", 0) = ctx->InputDType("cur_rank_inverse_indices", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/onerec_reader_op.cpp 
b/oneflow/user/ops/onerec_reader_op.cpp index 7a53d7d584a..95b34f8dbf4 100644 --- a/oneflow/user/ops/onerec_reader_op.cpp +++ b/oneflow/user/ops/onerec_reader_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { } /*static*/ Maybe OneRecReaderOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kTensorBuffer; + *ctx->MutOutputDType("out", 0) = DataType::kTensorBuffer; return Maybe::Ok(); } diff --git a/oneflow/user/ops/ones_like_op.cpp b/oneflow/user/ops/ones_like_op.cpp index c64eefc2a0f..fd04c07a028 100644 --- a/oneflow/user/ops/ones_like_op.cpp +++ b/oneflow/user/ops/ones_like_op.cpp @@ -41,7 +41,7 @@ namespace oneflow { return OnesLikeOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe OnesLikeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("like", 0); return Maybe::Ok(); } /*static*/ Maybe OnesLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { diff --git a/oneflow/user/ops/p2p_comm_op.cpp b/oneflow/user/ops/p2p_comm_op.cpp index 0c6998bdb87..2b76dbdcd1f 100644 --- a/oneflow/user/ops/p2p_comm_op.cpp +++ b/oneflow/user/ops/p2p_comm_op.cpp @@ -55,7 +55,7 @@ Maybe> GetRecvOutputDeivce(user_op::DeviceAndStreamInferContext* return SendOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe RecvOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + *ctx->MutOutputDType("out", 0) = ctx->Attr("dtype"); return Maybe::Ok(); } /*static*/ Maybe> RecvOp::InferDeviceAndStream( diff --git a/oneflow/user/ops/pack_op.cpp b/oneflow/user/ops/pack_op.cpp index b5ae5c75a74..828192e77e2 100644 --- a/oneflow/user/ops/pack_op.cpp +++ b/oneflow/user/ops/pack_op.cpp @@ -51,7 +51,7 @@ namespace oneflow { return PackOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe PackOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 
0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /*static*/ Maybe PackOp::InferOutputBlobTimeShape( diff --git a/oneflow/user/ops/pad_op.cpp b/oneflow/user/ops/pad_op.cpp index d1d020ed355..74ba9d17c6b 100644 --- a/oneflow/user/ops/pad_op.cpp +++ b/oneflow/user/ops/pad_op.cpp @@ -47,7 +47,7 @@ namespace oneflow { return PadOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe PadOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/padding_ops.cpp b/oneflow/user/ops/padding_ops.cpp index 41ef1da54ea..34f6b737829 100644 --- a/oneflow/user/ops/padding_ops.cpp +++ b/oneflow/user/ops/padding_ops.cpp @@ -81,7 +81,7 @@ Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { return ReflectionPad2DOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReflectionPad2DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } /*static*/ Maybe ReflectionPad2DOp::ModifyInputArg( @@ -120,7 +120,7 @@ Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { return ReflectionPad2DGradOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReflectionPad2DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } @@ -169,7 +169,7 @@ REGISTER_USER_OP_GRAD("reflection_pad2d") return ReplicationPad2DOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReplicationPad2DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } /*static*/ Maybe ReplicationPad2DOp::ModifyInputArg( @@ -208,7 +208,7 @@ REGISTER_USER_OP_GRAD("reflection_pad2d") return 
ReplicationPad2DGradOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReplicationPad2DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/parallel_cast_op.cpp b/oneflow/user/ops/parallel_cast_op.cpp index 9d25b9504de..1ed4a96dbbe 100644 --- a/oneflow/user/ops/parallel_cast_op.cpp +++ b/oneflow/user/ops/parallel_cast_op.cpp @@ -24,14 +24,14 @@ namespace oneflow { } /*static*/ Maybe ParallelCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } /*static*/ Maybe ParallelCastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return ParallelCastOp::InferLogicalTensorDesc(ctx); } /*static*/ Maybe ParallelCastOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /*static*/ Maybe ParallelCastOp::InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx) { diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp index 1798e91fe6d..e398e3b15d6 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -68,9 +68,9 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("mapped_label", 0) = ctx->InputDType("label", 0); - *ctx->OutputDType("sampled_weight", 0) = ctx->InputDType("weight", 0); - *ctx->OutputDType("sampled_label", 0) = ctx->InputDType("label", 0); + *ctx->MutOutputDType("mapped_label", 0) = ctx->InputDType("label", 0); + 
*ctx->MutOutputDType("sampled_weight", 0) = ctx->InputDType("weight", 0); + *ctx->MutOutputDType("sampled_label", 0) = ctx->InputDType("label", 0); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleOp::ModifyInputArg( @@ -113,18 +113,18 @@ namespace oneflow { user_op::InferContext* ctx) { *ctx->OutputShape("boxing_disabled_sampled_weight_diff", 0) = ctx->InputShape("sampled_weight_diff", 0); - *ctx->OutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = + *ctx->MutOutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = ctx->InputIsDynamic("sampled_weight_diff", 0); *ctx->OutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); - *ctx->OutputIsDynamic("boxing_disabled_sampled_label", 0) = + *ctx->MutOutputIsDynamic("boxing_disabled_sampled_label", 0) = ctx->InputIsDynamic("sampled_label", 0); return Maybe::Ok(); } /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferDataType( user_op::InferContext* ctx) { - *ctx->OutputDType("boxing_disabled_sampled_weight_diff", 0) = + *ctx->MutOutputDType("boxing_disabled_sampled_weight_diff", 0) = ctx->InputDType("sampled_weight_diff", 0); - *ctx->OutputDType("boxing_disabled_sampled_label", 0) = ctx->InputDType("sampled_label", 0); + *ctx->MutOutputDType("boxing_disabled_sampled_label", 0) = ctx->InputDType("sampled_label", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/prelu_op.cpp b/oneflow/user/ops/prelu_op.cpp index 6cd352ba5ba..d9bb036a96c 100644 --- a/oneflow/user/ops/prelu_op.cpp +++ b/oneflow/user/ops/prelu_op.cpp @@ -50,7 +50,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe PreluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -105,8 +105,8 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe PreluGradOp::InferDataType(user_op::InferContext* ctx) { - 
*ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/quantization_op.cpp b/oneflow/user/ops/quantization_op.cpp index 2396a1a1685..ad129ad9c27 100644 --- a/oneflow/user/ops/quantization_op.cpp +++ b/oneflow/user/ops/quantization_op.cpp @@ -75,7 +75,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe QuantizationOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /*static*/ Maybe QuantizationOp::ModifyInputArg( diff --git a/oneflow/user/ops/randperm_op.cpp b/oneflow/user/ops/randperm_op.cpp index 956902154ae..9343ec741f8 100644 --- a/oneflow/user/ops/randperm_op.cpp +++ b/oneflow/user/ops/randperm_op.cpp @@ -51,7 +51,7 @@ namespace oneflow { } /*static*/ Maybe RandpermOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kInt32; + *ctx->MutOutputDType("out", 0) = DataType::kInt32; return Maybe::Ok(); } diff --git a/oneflow/user/ops/reduce_like_ops.cpp b/oneflow/user/ops/reduce_like_ops.cpp index 64d5db36a67..381c0c52ccc 100644 --- a/oneflow/user/ops/reduce_like_ops.cpp +++ b/oneflow/user/ops/reduce_like_ops.cpp @@ -93,7 +93,7 @@ namespace oneflow { const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); CHECK_EQ_OR_RETURN(x_tensor.data_type(), like_tensor.data_type()) << Error::TypeError() << "Tensors x and like must have the same type"; - *ctx->OutputDType("y", 0) = like_tensor.data_type(); + *ctx->MutOutputDType("y", 0) = like_tensor.data_type(); return Maybe::Ok(); } /*static*/ Maybe ReduceSumLikeOp::ModifyInputArg( diff --git a/oneflow/user/ops/reduce_ops.cpp b/oneflow/user/ops/reduce_ops.cpp index 
5ac0a70038c..e853b4d79fc 100644 --- a/oneflow/user/ops/reduce_ops.cpp +++ b/oneflow/user/ops/reduce_ops.cpp @@ -43,12 +43,12 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { } Maybe InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("output_tensor", 0) = ctx->InputDType("input_tensor", 0); + *ctx->MutOutputDType("output_tensor", 0) = ctx->InputDType("input_tensor", 0); return Maybe::Ok(); } Maybe InferLogicalDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("output_tensor", 0) = DataType::kBool; + *ctx->MutOutputDType("output_tensor", 0) = DataType::kBool; return Maybe::Ok(); } diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp index 38e4f58328a..b05e831aeea 100644 --- a/oneflow/user/ops/relu_op.cpp +++ b/oneflow/user/ops/relu_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -66,7 +66,7 @@ namespace oneflow { const DataType& data_type = ctx->InputDType("y", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), data_type) << Error::TypeError() << "Tensors dy and y must have the same type"; - *ctx->OutputDType("dx", 0) = data_type; + *ctx->MutOutputDType("dx", 0) = data_type; return Maybe::Ok(); } diff --git a/oneflow/user/ops/repeat_interleave_op.cpp b/oneflow/user/ops/repeat_interleave_op.cpp index ec77a9efe3b..22742f9cb2f 100644 --- a/oneflow/user/ops/repeat_interleave_op.cpp +++ b/oneflow/user/ops/repeat_interleave_op.cpp @@ -44,7 +44,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe Repeat_InterLeaveOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/repeat_op.cpp 
b/oneflow/user/ops/repeat_op.cpp index 60b281854dc..babe9f38bfd 100644 --- a/oneflow/user/ops/repeat_op.cpp +++ b/oneflow/user/ops/repeat_op.cpp @@ -32,14 +32,14 @@ namespace oneflow { } /*static*/ Maybe RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); return Maybe::Ok(); } /*static*/ Maybe RepeatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RepeatOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /*static*/ Maybe RepeatOp::InferOutputBlobTimeShape( diff --git a/oneflow/user/ops/reshape_like_op.cpp b/oneflow/user/ops/reshape_like_op.cpp index 7b11d6de6f0..b45455c8ea9 100644 --- a/oneflow/user/ops/reshape_like_op.cpp +++ b/oneflow/user/ops/reshape_like_op.cpp @@ -51,7 +51,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReshapeLikeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /*static*/ Maybe ReshapeLikeOp::ModifyInputArg( diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 4876e1f833c..77b4fc35c3d 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -124,7 +124,7 @@ namespace oneflow { } /*static*/ Maybe ReshapeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/roc_auc_score_op.cpp b/oneflow/user/ops/roc_auc_score_op.cpp index 9a7e68ed524..19c428dae90 100644 --- 
a/oneflow/user/ops/roc_auc_score_op.cpp +++ b/oneflow/user/ops/roc_auc_score_op.cpp @@ -38,7 +38,7 @@ namespace oneflow { } /* static */ Maybe RocAucScoreOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kFloat; + *ctx->MutOutputDType("out", 0) = DataType::kFloat; const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); CHECK_OR_RETURN(IsFloatingDataType(label.data_type()) || IsIntegralDataType(label.data_type())) << "Input `label` data type " << DataType_Name(label.data_type()) << " is not supported."; diff --git a/oneflow/user/ops/roi_align_op.cpp b/oneflow/user/ops/roi_align_op.cpp index eeb77b9f4ea..9498e077a64 100644 --- a/oneflow/user/ops/roi_align_op.cpp +++ b/oneflow/user/ops/roi_align_op.cpp @@ -50,7 +50,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RoiAlignOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } /*static*/ Maybe RoiAlignOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, @@ -106,7 +106,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x_like", 0)) << Error::TypeError() << "The dy tensor and x_like tensor must have same type"; - *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x_like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/roll_op.cpp b/oneflow/user/ops/roll_op.cpp index a22c27552d0..15904842508 100644 --- a/oneflow/user/ops/roll_op.cpp +++ b/oneflow/user/ops/roll_op.cpp @@ -52,7 +52,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe RollOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/same_padding_op.cpp 
b/oneflow/user/ops/same_padding_op.cpp index 267faf5fecf..b2e9bd4dd57 100644 --- a/oneflow/user/ops/same_padding_op.cpp +++ b/oneflow/user/ops/same_padding_op.cpp @@ -71,7 +71,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SamePaddingOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -109,14 +109,14 @@ namespace oneflow { } /*static*/ Maybe SamePaddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); + *ctx->MutOutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); return Maybe::Ok(); } /*static*/ Maybe SamePaddingGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SamePaddingGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x_like", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/scalar_logical_op.cpp b/oneflow/user/ops/scalar_logical_op.cpp index 8c0786c2804..f518a1d0f6b 100644 --- a/oneflow/user/ops/scalar_logical_op.cpp +++ b/oneflow/user/ops/scalar_logical_op.cpp @@ -28,14 +28,14 @@ namespace oneflow { } \ /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ return Maybe::Ok(); \ } \ /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ - *ctx->OutputDType("out", 0) = DataType::kBool; \ + 
*ctx->MutOutputDType("out", 0) = DataType::kBool; \ return Maybe::Ok(); \ } diff --git a/oneflow/user/ops/scalar_math_op.cpp b/oneflow/user/ops/scalar_math_op.cpp index 3627acde3cf..33eba52d1d0 100644 --- a/oneflow/user/ops/scalar_math_op.cpp +++ b/oneflow/user/ops/scalar_math_op.cpp @@ -43,14 +43,14 @@ Maybe GetSbp4ScalarMul(user_op::SbpContext* ctx) { /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { return get_sbp_fn(ctx); } \ /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ + *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ return Maybe::Ok(); \ } \ /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ return InferLogicalTensorDesc(ctx); \ } \ /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); \ + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); \ return Maybe::Ok(); \ } @@ -80,7 +80,7 @@ IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath) /*static*/ Maybe ScalarPowGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << Error::TypeError() << "Tensors dy and x must have same type"; - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -101,7 +101,7 @@ IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath) /*static*/ Maybe ScalarReversePowGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)) << Error::TypeError() << "Tensors dy and x must have same type"; - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git 
a/oneflow/user/ops/search_sorted_op.cpp b/oneflow/user/ops/search_sorted_op.cpp index 368114c17ec..b9fcc0f1164 100644 --- a/oneflow/user/ops/search_sorted_op.cpp +++ b/oneflow/user/ops/search_sorted_op.cpp @@ -46,9 +46,9 @@ namespace oneflow { /* static */ Maybe SearchSortedOp::InferDataType(user_op::InferContext* ctx) { const bool& out_int32 = ctx->Attr("out_int32"); if (out_int32) { - *ctx->OutputDType("out", 0) = DataType::kInt32; + *ctx->MutOutputDType("out", 0) = DataType::kInt32; } else { - *ctx->OutputDType("out", 0) = DataType::kInt64; + *ctx->MutOutputDType("out", 0) = DataType::kInt64; } return Maybe::Ok(); } @@ -74,9 +74,9 @@ namespace oneflow { /* static */ Maybe SearchSortedScalarOp::InferDataType(user_op::InferContext* ctx) { const bool& out_int32 = ctx->Attr("out_int32"); if (out_int32) { - *ctx->OutputDType("out", 0) = DataType::kInt32; + *ctx->MutOutputDType("out", 0) = DataType::kInt32; } else { - *ctx->OutputDType("out", 0) = DataType::kInt64; + *ctx->MutOutputDType("out", 0) = DataType::kInt64; } return Maybe::Ok(); } diff --git a/oneflow/user/ops/selu_op.cpp b/oneflow/user/ops/selu_op.cpp index e23a95c8526..f73528f9852 100644 --- a/oneflow/user/ops/selu_op.cpp +++ b/oneflow/user/ops/selu_op.cpp @@ -33,7 +33,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SeluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -63,7 +63,7 @@ namespace oneflow { /*static*/ Maybe SeluGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)) << Error::TypeError() << "Tensors dy and x must have same type"; - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp 
b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp index 3ec411e429e..2221d06017a 100644 --- a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp +++ b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp @@ -45,7 +45,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SigmoidCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("loss", 0) = ctx->InputDType("prediction", 0); + *ctx->MutOutputDType("loss", 0) = ctx->InputDType("prediction", 0); return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyOp::ModifyInputArg( @@ -89,7 +89,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SigmoidCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prediction", 0); + *ctx->MutOutputDType("prediction_diff", 0) = ctx->InputDType("prediction", 0); return Maybe::Ok(); } /*static*/ Maybe SigmoidCrossEntropyGradOp::ModifyInputArg( diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp index 8e35ae69ab1..ee08e239336 100644 --- a/oneflow/user/ops/silu_op.cpp +++ b/oneflow/user/ops/silu_op.cpp @@ -33,7 +33,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SiluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -65,7 +65,7 @@ namespace oneflow { << Error::TypeError() << "dy and x are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 3ae88200258..a3d652414ab 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -202,7 +202,7 @@ 
bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { return Maybe::Ok(); } /*static*/ Maybe SliceOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -273,7 +273,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { return Maybe::Ok(); } /*static*/ Maybe SliceGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } /*static*/ Maybe SliceGradOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, diff --git a/oneflow/user/ops/smooth_l1_loss_op.cpp b/oneflow/user/ops/smooth_l1_loss_op.cpp index 51917208a16..85859963ae7 100644 --- a/oneflow/user/ops/smooth_l1_loss_op.cpp +++ b/oneflow/user/ops/smooth_l1_loss_op.cpp @@ -56,7 +56,7 @@ namespace oneflow { << Error::TypeError() << "input and target are expected to have the same dtype, but found " << DataType_Name(input_desc.data_type()) << " and " << DataType_Name(target_desc.data_type()); - *ctx->OutputDType("out", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } @@ -115,7 +115,7 @@ namespace oneflow { << Error::TypeError() << "input and target are expected to have the same dtype, but found " << DataType_Name(input_desc.data_type()) << " and " << DataType_Name(target_desc.data_type()); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp index 1b31f895407..06765de1731 100644 --- a/oneflow/user/ops/softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp @@ -52,7 +52,7 @@ namespace oneflow { 
out_dim_vector.emplace_back(prediction_desc.shape().At(i)); } *ctx->OutputShape("prob", 0) = ctx->InputShape("prediction", 0); - *ctx->OutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0); + *ctx->MutOutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0); user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); *out_desc->mut_shape() = Shape(out_dim_vector); @@ -69,7 +69,7 @@ namespace oneflow { << "label and prediction are expected to have the same dtype, but found " << DataType_Name(label_desc.data_type()) << " and " << DataType_Name(prediction_desc.data_type()); - *ctx->OutputDType("prob", 0) = ctx->InputDType("prediction", 0); + *ctx->MutOutputDType("prob", 0) = ctx->InputDType("prediction", 0); user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); *out_desc->mut_data_type() = prediction_desc.data_type(); return Maybe::Ok(); @@ -119,7 +119,7 @@ namespace oneflow { << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prob " << prob_desc.shape(); *ctx->OutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0); - *ctx->OutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0); + *ctx->MutOutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0); return Maybe::Ok(); } /*static*/ Maybe SoftmaxCrossEntropyGradOp::InferPhysicalTensorDesc( @@ -136,7 +136,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type()) << Error::TypeError() << "dy and prob are expected to have the same dtype, but found " << DataType_Name(dy_desc.data_type()) << " and " << DataType_Name(prob_desc.data_type()); - *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prob", 0); + *ctx->MutOutputDType("prediction_diff", 0) = ctx->InputDType("prob", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/softmax_op.cpp b/oneflow/user/ops/softmax_op.cpp index a726d561073..d185d85aec4 
100644 --- a/oneflow/user/ops/softmax_op.cpp +++ b/oneflow/user/ops/softmax_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SoftmaxOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -68,7 +68,7 @@ namespace oneflow { << Error::TypeError() << "dy and y are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("y", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/softplus_op.cpp b/oneflow/user/ops/softplus_op.cpp index 2a772b661c0..4731a2cc55f 100644 --- a/oneflow/user/ops/softplus_op.cpp +++ b/oneflow/user/ops/softplus_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe SoftplusOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -71,7 +71,7 @@ namespace oneflow { << Error::TypeError() << "dy and x are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/softshrink_op.cpp b/oneflow/user/ops/softshrink_op.cpp index 95ec290270b..65427553447 100644 --- a/oneflow/user/ops/softshrink_op.cpp +++ b/oneflow/user/ops/softshrink_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe SoftShrinkOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } 
@@ -71,7 +71,7 @@ namespace oneflow { << Error::TypeError() << "dy and y are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("y", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("y", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/softsign_op.cpp b/oneflow/user/ops/softsign_op.cpp index 61e45f781e6..6b15b3479e4 100644 --- a/oneflow/user/ops/softsign_op.cpp +++ b/oneflow/user/ops/softsign_op.cpp @@ -33,7 +33,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SoftsignOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -65,7 +65,7 @@ namespace oneflow { << Error::TypeError() << "dy and x are expected to have the same dtype, but found " << DataType_Name(ctx->InputDType("dy", 0)) << " and " << DataType_Name(ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sort_op.cpp b/oneflow/user/ops/sort_op.cpp index f2dd5e6f89b..3b1fb41506f 100644 --- a/oneflow/user/ops/sort_op.cpp +++ b/oneflow/user/ops/sort_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SortOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } /*static*/ Maybe SortOp::CheckAttr(const user_op::UserOpDefWrapper&, diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp index b661910fe8c..e837b70a557 100644 --- a/oneflow/user/ops/sparse_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp @@ -63,7 +63,7 @@ 
Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { << Error::RuntimeError() << "The size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); *ctx->OutputShape("prediction_diff", 0) = prediction_desc.shape(); - *ctx->OutputIsDynamic("prediction_diff", 0) = prediction_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("prediction_diff", 0) = prediction_desc.is_dynamic(); return Maybe::Ok(); } @@ -89,7 +89,7 @@ Maybe InferDataTypeGrad(user_op::InferContext* ctx) { << Error::TypeError() << "dy and prediction are expected to have the same dtype, but found " << DataType_Name(dy_desc.data_type()) << " and " << DataType_Name(prediction_desc.data_type()); - *ctx->OutputDType("prediction_diff", 0) = prediction_desc.data_type(); + *ctx->MutOutputDType("prediction_diff", 0) = prediction_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp index 0d77af3f218..1416d257ba9 100644 --- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp @@ -41,7 +41,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { << Error::RuntimeError() << "The size of prediction (" << prediction_desc.shape().At(i) << ") must match the size of label (" << label_desc.shape().At(i) << ") at dimension " << i; } - *ctx->OutputIsDynamic("prob", 0) = prediction_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("prob", 0) = prediction_desc.is_dynamic(); // 'prob' is just for compute prediction's grad, prob's grad will be ignored *ctx->OutputShape("prob", 0) = prediction_desc.shape(); user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); @@ -76,7 +76,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { << Error::RuntimeError() << "The size of dy " << dy_desc.shape() << " must match the size of label " << label_desc.shape(); *ctx->OutputShape("prediction_diff", 0) = prob_desc.shape(); - 
*ctx->OutputIsDynamic("prediction_diff", 0) = prob_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("prediction_diff", 0) = prob_desc.is_dynamic(); return Maybe::Ok(); } @@ -85,8 +85,8 @@ Maybe InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type())) << Error::TypeError() << "The dtype of label must be integer, but found " << DataType_Name(label_desc.data_type()); - *ctx->OutputDType("prob", 0) = ctx->InputDType("prediction", 0); - *ctx->OutputDType("out", 0) = ctx->InputDType("prediction", 0); + *ctx->MutOutputDType("prob", 0) = ctx->InputDType("prediction", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("prediction", 0); return Maybe::Ok(); } @@ -100,7 +100,7 @@ Maybe InferDataTypeGrad(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type()) << Error::TypeError() << "dy and prob are expected to have the same dtype, but found " << DataType_Name(dy_desc.data_type()) << " and " << DataType_Name(prob_desc.data_type()); - *ctx->OutputDType("prediction_diff", 0) = prob_desc.data_type(); + *ctx->MutOutputDType("prediction_diff", 0) = prob_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sqrt_square_sum_op.cpp b/oneflow/user/ops/sqrt_square_sum_op.cpp index f8c6b43ca5b..4766f0628ec 100644 --- a/oneflow/user/ops/sqrt_square_sum_op.cpp +++ b/oneflow/user/ops/sqrt_square_sum_op.cpp @@ -34,7 +34,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SqrtSquareSumOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/square_sum_op.cpp b/oneflow/user/ops/square_sum_op.cpp index bb53097df89..3748c184770 100644 --- a/oneflow/user/ops/square_sum_op.cpp +++ b/oneflow/user/ops/square_sum_op.cpp @@ -34,7 +34,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe 
SquareSumOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/squeeze_op.cpp b/oneflow/user/ops/squeeze_op.cpp index d6c9cb111a4..70ac20c4c89 100644 --- a/oneflow/user/ops/squeeze_op.cpp +++ b/oneflow/user/ops/squeeze_op.cpp @@ -78,7 +78,7 @@ Maybe CheckAndLabelAxesToSqueezeMinusOne(const AxisVector& axes, DimVector return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SqueezeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/ssp_variable_proxy_op.cpp b/oneflow/user/ops/ssp_variable_proxy_op.cpp index 9a5a31262a7..1df65634f54 100644 --- a/oneflow/user/ops/ssp_variable_proxy_op.cpp +++ b/oneflow/user/ops/ssp_variable_proxy_op.cpp @@ -39,8 +39,8 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe SspVariableProxyOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("ref", 0) = ctx->InputDType("var", 0); - *ctx->OutputDType("value", 0) = ctx->InputDType("var", 0); + *ctx->MutOutputDType("ref", 0) = ctx->InputDType("var", 0); + *ctx->MutOutputDType("value", 0) = ctx->InputDType("var", 0); return Maybe::Ok(); } /*static*/ Maybe SspVariableProxyOp::ModifyOutputArg( diff --git a/oneflow/user/ops/tf_pool_op.cpp b/oneflow/user/ops/tf_pool_op.cpp index 39afc8478b8..e71989fd086 100644 --- a/oneflow/user/ops/tf_pool_op.cpp +++ b/oneflow/user/ops/tf_pool_op.cpp @@ -52,17 +52,17 @@ TensorDescInferFn MakeFwTensorDescInferFn(const int32_t dim) { Maybe BwTensorDescInferFn(user_op::InferContext* ctx) { *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); + *ctx->MutOutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); } Maybe 
FwInferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } Maybe BwInferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/tf_prelu_op.cpp b/oneflow/user/ops/tf_prelu_op.cpp index b4880e201e7..342801b9272 100644 --- a/oneflow/user/ops/tf_prelu_op.cpp +++ b/oneflow/user/ops/tf_prelu_op.cpp @@ -54,7 +54,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TfPreluOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -103,15 +103,15 @@ namespace oneflow { *dx_desc->mut_shape() = x_desc.shape(); *dx_desc->mut_is_dynamic() = x_desc.is_dynamic(); *ctx->OutputShape("alpha_diff", 0) = alpha_desc.shape(); - *ctx->OutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic(); + *ctx->MutOutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic(); return Maybe::Ok(); } /*static*/ Maybe TfPreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TfPreluGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/threshold_op.cpp b/oneflow/user/ops/threshold_op.cpp index 3cf10ab9dae..3a3b2360bc5 100644 --- a/oneflow/user/ops/threshold_op.cpp +++ b/oneflow/user/ops/threshold_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /* static */ Maybe ThresholdOp::InferDataType(user_op::InferContext* ctx) { - 
*ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -67,7 +67,7 @@ namespace oneflow { /* static */ Maybe ThresholdGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/to_contiguous_op.cpp b/oneflow/user/ops/to_contiguous_op.cpp index 95a80c3e1b6..c51b21faaf6 100644 --- a/oneflow/user/ops/to_contiguous_op.cpp +++ b/oneflow/user/ops/to_contiguous_op.cpp @@ -32,7 +32,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ToContiguousOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/top_k_op.cpp b/oneflow/user/ops/top_k_op.cpp index 0bcf295d5bd..cfa00d070cb 100644 --- a/oneflow/user/ops/top_k_op.cpp +++ b/oneflow/user/ops/top_k_op.cpp @@ -39,7 +39,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TopKOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kInt64; + *ctx->MutOutputDType("out", 0) = DataType::kInt64; return Maybe::Ok(); } diff --git a/oneflow/user/ops/transpose_ops.cpp b/oneflow/user/ops/transpose_ops.cpp index 9d8130e6efb..2b483d8f449 100644 --- a/oneflow/user/ops/transpose_ops.cpp +++ b/oneflow/user/ops/transpose_ops.cpp @@ -60,7 +60,7 @@ void CheckIsPerm(const std::vector& perm) { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe TransposeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } diff --git 
a/oneflow/user/ops/tuple_identity_op.cpp b/oneflow/user/ops/tuple_identity_op.cpp index dd98f2fef74..7b9151f5146 100644 --- a/oneflow/user/ops/tuple_identity_op.cpp +++ b/oneflow/user/ops/tuple_identity_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); for (int64_t i = 0; i < in_size; ++i) { *ctx->OutputShape("out", i) = ctx->InputShape("in", i); - *ctx->IsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i); + *ctx->MutIsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i); } return Maybe::Ok(); } @@ -37,7 +37,9 @@ namespace oneflow { /*static*/ Maybe TupleIdentityOp::InferDataType(user_op::InferContext* ctx) { const int64_t in_size = ctx->input_size("in"); CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); - for (int64_t i = 0; i < in_size; ++i) { *ctx->OutputDType("out", i) = ctx->InputDType("in", i); } + for (int64_t i = 0; i < in_size; ++i) { + *ctx->MutOutputDType("out", i) = ctx->InputDType("in", i); + } return Maybe::Ok(); } /*static*/ Maybe TupleIdentityOp::InferSbpSignature( diff --git a/oneflow/user/ops/two_stage_reduce_ops.cpp b/oneflow/user/ops/two_stage_reduce_ops.cpp index 9fbb79e1da0..1e3f1239cb8 100644 --- a/oneflow/user/ops/two_stage_reduce_ops.cpp +++ b/oneflow/user/ops/two_stage_reduce_ops.cpp @@ -23,9 +23,9 @@ namespace oneflow { namespace { Maybe InferReduceDeviceStageDtypeFn(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("mask", 0) = DataType::kBool; - *ctx->OutputDType("count", 0) = DataType::kInt32; + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("mask", 0) = DataType::kBool; + *ctx->MutOutputDType("count", 0) = DataType::kInt32; return Maybe::Ok(); } @@ -90,7 +90,7 @@ Maybe InferReduceDeviceStagePhysicalTensorDescFn(user_op::InferContext* ct Maybe InferReduceDeviceStageGradDtypeFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), 
DataType::kBool); CHECK_EQ_OR_RETURN(ctx->InputDType("count", 0), DataType::kInt32); - *ctx->OutputDType("in_diff", 0) = ctx->InputDType("out_diff", 0); + *ctx->MutOutputDType("in_diff", 0) = ctx->InputDType("out_diff", 0); return Maybe::Ok(); } @@ -102,8 +102,8 @@ Maybe InferReduceDeviceStageGradTensorDescFn(user_op::InferContext* ctx) { Maybe InferReduceGlobalStageDtypeFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("device_count", 0), DataType::kInt32); - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("mask", 0) = DataType::kBool; + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->MutOutputDType("mask", 0) = DataType::kBool; return Maybe::Ok(); } @@ -140,7 +140,7 @@ Maybe InferReduceGlobalStageGradDtypeFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool); CHECK_EQ_OR_RETURN(ctx->InputDType("device_count", 0), DataType::kInt32); - *ctx->OutputDType("in_diff", 0) = ctx->InputDType("out_diff", 0); + *ctx->MutOutputDType("in_diff", 0) = ctx->InputDType("out_diff", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/unfold_fold_op.cpp b/oneflow/user/ops/unfold_fold_op.cpp index 0560561604c..3de3f5acd3f 100644 --- a/oneflow/user/ops/unfold_fold_op.cpp +++ b/oneflow/user/ops/unfold_fold_op.cpp @@ -63,7 +63,7 @@ Maybe UnfoldTensorDescInferFn(user_op::InferContext* ctx) { } Maybe SetUnfoldDTypeFn(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -123,7 +123,7 @@ Maybe FoldTensorDescInferFn(user_op::InferContext* ctx) { } Maybe FoldDTypeFn(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/unfold_tensor_op.cpp b/oneflow/user/ops/unfold_tensor_op.cpp index 52d1c068e6b..e6aafcae05c 100644 --- 
a/oneflow/user/ops/unfold_tensor_op.cpp +++ b/oneflow/user/ops/unfold_tensor_op.cpp @@ -64,7 +64,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnfoldTensorOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -94,7 +94,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnfoldTensorGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/unsorted_segment_sum_op.cpp b/oneflow/user/ops/unsorted_segment_sum_op.cpp index 5df5e81e451..145bdca888e 100644 --- a/oneflow/user/ops/unsorted_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_segment_sum_op.cpp @@ -69,7 +69,7 @@ namespace oneflow { } /*static*/ Maybe UnsortedSegmentSumOp::InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); - *ctx->OutputDType("out", 0) = ctx->InputDType("data", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("data", 0); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumOp::ModifyInputArg( @@ -164,7 +164,7 @@ REGISTER_USER_OP_GRAD("unsorted_segment_sum") CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); } *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - *ctx->IsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); + *ctx->MutIsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumLikeOp::InferPhysicalTensorDesc( @@ -176,7 +176,7 @@ REGISTER_USER_OP_GRAD("unsorted_segment_sum") const user_op::TensorDesc& like = ctx->InputTensorDesc("like", 0); CHECK_EQ_OR_RETURN(data.data_type(), like.data_type()); 
CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("like", 0); return Maybe::Ok(); } /*static*/ Maybe UnsortedSegmentSumLikeOp::ModifyInputArg( diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index e1d05c1b097..badb48d129a 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -43,7 +43,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleLinear1DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -71,7 +71,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest1DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -102,7 +102,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest2DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -133,7 +133,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleBilinear2DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -164,7 +164,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleBicubic2DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -197,7 +197,7 @@ namespace oneflow { return 
InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest3DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -230,7 +230,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleTrilinear3DOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->MutOutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } @@ -255,7 +255,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleLinear1DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } @@ -281,7 +281,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest1DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } @@ -307,7 +307,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest2DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } @@ -334,7 +334,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleBilinear2DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } @@ -360,7 +360,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleBicubic2DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + 
*ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } @@ -386,7 +386,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleNearest3DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } @@ -413,7 +413,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UpsampleTrilinear3DGradOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + *ctx->MutOutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/util_ops.cpp b/oneflow/user/ops/util_ops.cpp index 0be4ce5f115..5f070b9fe6b 100644 --- a/oneflow/user/ops/util_ops.cpp +++ b/oneflow/user/ops/util_ops.cpp @@ -38,7 +38,7 @@ namespace oneflow { } /* static */ Maybe IsNanOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kBool; + *ctx->MutOutputDType("out", 0) = DataType::kBool; return Maybe::Ok(); } @@ -62,7 +62,7 @@ namespace oneflow { } /* static */ Maybe IsInfOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = DataType::kBool; + *ctx->MutOutputDType("out", 0) = DataType::kBool; return Maybe::Ok(); } diff --git a/oneflow/user/ops/variance_op.cpp b/oneflow/user/ops/variance_op.cpp index c1e578e6947..144f9bdd96a 100644 --- a/oneflow/user/ops/variance_op.cpp +++ b/oneflow/user/ops/variance_op.cpp @@ -41,7 +41,7 @@ Maybe VarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { } Maybe VarOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + *ctx->MutOutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/vector_matrix_product_op.cpp b/oneflow/user/ops/vector_matrix_product_op.cpp index 834ace4ab4c..3d45ae148c5 100644 --- 
a/oneflow/user/ops/vector_matrix_product_op.cpp +++ b/oneflow/user/ops/vector_matrix_product_op.cpp @@ -34,7 +34,7 @@ Maybe InferDataType4VectorMatrixProduct(user_op::InferContext* ctx) { const DataType& dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype) << "Matrix A datatype should be equal to Vector B. "; - *ctx->OutputDType("out", 0) = dtype; + *ctx->MutOutputDType("out", 0) = dtype; return Maybe::Ok(); } @@ -64,7 +64,7 @@ Maybe InferTensorDesc4VectorMatrixProductGradB(user_op::InferContext* ctx) Maybe InferDataType4Grad(user_op::InferContext* ctx) { const DataType& dtype = ctx->InputDType("dy", 0); - *ctx->OutputDType("dx", 0) = dtype; + *ctx->MutOutputDType("dx", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index e49ffb19fe6..b04216941e0 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -213,7 +213,7 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype)); const DataType& x_dtype = ctx->InputDType("x", 0); CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType("y", 0)); - *ctx->OutputDType("out", 0) = x_dtype; + *ctx->MutOutputDType("out", 0) = x_dtype; return Maybe::Ok(); } /*static*/ Maybe WhereOp::ModifyInputArg(const GetInputArgModifier& f, @@ -244,7 +244,7 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie CHECK_EQ_OR_RETURN(y_dtype, GetDataType::value) << "expected scalar type " << GetDataType::value << "but found " << y_dtype; } - *ctx->OutputDType("out", 0) = y_dtype; + *ctx->MutOutputDType("out", 0) = y_dtype; return Maybe::Ok(); } /*static*/ Maybe WhereScalarXOp::ModifyInputArg(const GetInputArgModifier& f, @@ -275,7 +275,7 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie CHECK_EQ_OR_RETURN(x_dtype, GetDataType::value) << "expected scalar type " << GetDataType::value << "but found 
" << x_dtype; } - *ctx->OutputDType("out", 0) = x_dtype; + *ctx->MutOutputDType("out", 0) = x_dtype; return Maybe::Ok(); } /*static*/ Maybe WhereScalarYOp::ModifyInputArg(const GetInputArgModifier& f, @@ -296,11 +296,11 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie const DataType& cond_dtype = ctx->InputDType("condition", 0); CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype)); if (ctx->Attr("has_x_bool_operand") && ctx->Attr("has_y_bool_operand")) { - *ctx->OutputDType("out", 0) = GetDataType::value; + *ctx->MutOutputDType("out", 0) = GetDataType::value; } else if (ctx->Attr("has_x_int_operand") && ctx->Attr("has_y_int_operand")) { - *ctx->OutputDType("out", 0) = GetDataType::value; + *ctx->MutOutputDType("out", 0) = GetDataType::value; } else if (ctx->Attr("has_x_float_operand") && ctx->Attr("has_y_float_operand")) { - *ctx->OutputDType("out", 0) = GetDataType::value; + *ctx->MutOutputDType("out", 0) = GetDataType::value; } else { UNIMPLEMENTED(); } diff --git a/oneflow/user/ops/zero_like_op.cpp b/oneflow/user/ops/zero_like_op.cpp index ad648779684..d4c72bacc4f 100644 --- a/oneflow/user/ops/zero_like_op.cpp +++ b/oneflow/user/ops/zero_like_op.cpp @@ -40,7 +40,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ZeroLikeOp::InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + *ctx->MutOutputDType("out", 0) = ctx->InputDType("like", 0); return Maybe::Ok(); } /*static*/ Maybe ZeroLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { From 955e2c9ec33b48bf3729a6e4053105af7cda141c Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 15:29:43 +0800 Subject: [PATCH 44/67] define_mut_output_dtype_and_mut_output_tensor_desc --- oneflow/core/framework/infer_util.cpp | 2 +- oneflow/core/framework/infer_util.h | 3 +- oneflow/core/framework/op_expr.cpp | 7 +++- oneflow/core/framework/user_op_registry.cpp | 2 +- 
oneflow/core/kernel/user_kernel.cpp | 7 +++- oneflow/core/operator/user_op.cpp | 16 ++++--- oneflow/user/kernels/conv_cudnn_kernels.cpp | 40 +++++++++--------- oneflow/user/kernels/conv_kernels.cpp | 4 +- oneflow/user/kernels/deconv_cudnn_kernel.cpp | 4 +- oneflow/user/kernels/fused_gru_cell_kernel.cu | 2 +- .../user/kernels/fused_lstm_cell_kernel.cu | 2 +- oneflow/user/kernels/group_conv_kernel.cpp | 4 +- .../kernels/nccl_logical_2d_sbp_kernels.cpp | 6 +-- oneflow/user/kernels/nccl_logical_kernels.cpp | 12 +++--- oneflow/user/kernels/reduce_kernel.cpp | 2 +- oneflow/user/kernels/reduce_like_kernels.cpp | 2 +- oneflow/user/kernels/stateful_opkernel.cpp | 15 +++++-- oneflow/user/ops/add_n_op.cpp | 4 +- oneflow/user/ops/affine_grid_op.cpp | 8 ++-- oneflow/user/ops/amp_white_identity_op.cpp | 4 +- oneflow/user/ops/arg_where_op.cpp | 8 ++-- oneflow/user/ops/as_strided_op.cpp | 4 +- oneflow/user/ops/avg_pool_op.cpp | 4 +- oneflow/user/ops/batch_gather_op.cpp | 4 +- oneflow/user/ops/bernoulli_op.cpp | 4 +- oneflow/user/ops/binary_cross_entropy_op.cpp | 4 +- .../binary_cross_entropy_with_logits_op.cpp | 4 +- ...oss_entropy_with_logits_reduce_mean_op.cpp | 4 +- oneflow/user/ops/cast_like_op.cpp | 2 +- oneflow/user/ops/cast_op.cpp | 4 +- oneflow/user/ops/cast_to_static_shape_op.cpp | 2 +- oneflow/user/ops/coco_reader_op.cpp | 42 +++++++++---------- oneflow/user/ops/combined_margin_loss_op.cpp | 2 +- oneflow/user/ops/concat_op.cpp | 4 +- oneflow/user/ops/conv_op.cpp | 10 ++--- oneflow/user/ops/count_not_finite_op.cpp | 8 ++-- .../cublas_bias_add_relu_matmul_grad_op.cpp | 4 +- .../cublas_fused_matmul_bias_add_grad_op.cpp | 4 +- oneflow/user/ops/cublas_fused_mlp_op.cpp | 6 +-- oneflow/user/ops/deconv_op.cpp | 2 +- oneflow/user/ops/diag_op.cpp | 4 +- oneflow/user/ops/diagonal_op.cpp | 4 +- oneflow/user/ops/dim_gather_op.cpp | 4 +- oneflow/user/ops/dim_scatter_ops.cpp | 4 +- .../ops/elementwise_maximum_minimum_ops.cpp | 8 ++-- oneflow/user/ops/embedding_op.cpp | 4 +- 
oneflow/user/ops/flatten_op.cpp | 2 +- oneflow/user/ops/flip_op.cpp | 2 +- oneflow/user/ops/fused_cast_scale_op.cpp | 4 +- .../fused_matmul_bias_add_relu_dropout_op.cpp | 6 +-- .../fused_scale_mask_softmax_dropout_op.cpp | 4 +- .../user/ops/fused_scale_mask_softmax_op.cpp | 4 +- ...fused_scale_tril_softmax_mask_scale_op.cpp | 4 +- oneflow/user/ops/gather_op.cpp | 4 +- oneflow/user/ops/gpt_data_loader_op.cpp | 4 +- oneflow/user/ops/grid_sample_op.cpp | 6 +-- oneflow/user/ops/image_batch_align_op.cpp | 4 +- oneflow/user/ops/image_decode_op.cpp | 4 +- oneflow/user/ops/image_preprocess_ops.cpp | 14 +++---- oneflow/user/ops/image_resize_ops.cpp | 20 ++++----- oneflow/user/ops/image_target_resize_op.cpp | 12 +++--- oneflow/user/ops/in_top_k_op.cpp | 4 +- .../user/ops/indexed_slices_reduce_sum_op.cpp | 8 ++-- oneflow/user/ops/kl_div_op.cpp | 4 +- oneflow/user/ops/layer_norm_op.cpp | 24 +++++------ .../user/ops/math_binary_broadcast_ops.cpp | 2 +- oneflow/user/ops/matmul_op.cpp | 6 +-- oneflow/user/ops/max_pool_op.cpp | 8 ++-- oneflow/user/ops/mutable_cast_once_op.cpp | 4 +- oneflow/user/ops/narrow_op.cpp | 4 +- oneflow/user/ops/nll_op.cpp | 6 +-- oneflow/user/ops/normalization_op.cpp | 20 ++++----- oneflow/user/ops/ofrecord_decoder_ops.cpp | 16 +++---- ...frecord_image_classification_reader_op.cpp | 8 ++-- oneflow/user/ops/ofrecord_reader_op.cpp | 4 +- oneflow/user/ops/one_hot_op.cpp | 4 +- oneflow/user/ops/onerec_decoder_op.cpp | 4 +- oneflow/user/ops/onerec_reader_op.cpp | 2 +- oneflow/user/ops/pack_op.cpp | 2 +- oneflow/user/ops/partial_fc_sample_op.cpp | 16 +++---- oneflow/user/ops/reduce_like_ops.cpp | 2 +- oneflow/user/ops/repeat_interleave_op.cpp | 2 +- oneflow/user/ops/reshape_op.cpp | 4 +- oneflow/user/ops/roc_auc_score_op.cpp | 2 +- oneflow/user/ops/same_padding_op.cpp | 2 +- oneflow/user/ops/scalar_by_tensor_op.cpp | 4 +- oneflow/user/ops/sigmoid_cross_entropy_op.cpp | 4 +- oneflow/user/ops/slice_op.cpp | 6 +-- oneflow/user/ops/smooth_l1_loss_op.cpp | 4 +- 
oneflow/user/ops/softmax_cross_entropy_op.cpp | 4 +- oneflow/user/ops/sparse_cross_entropy_op.cpp | 4 +- .../ops/sparse_softmax_cross_entropy_op.cpp | 2 +- oneflow/user/ops/split_like_op.cpp | 4 +- oneflow/user/ops/sqrt_square_sum_op.cpp | 2 +- oneflow/user/ops/square_sum_op.cpp | 6 +-- oneflow/user/ops/stack_op.cpp | 8 ++-- oneflow/user/ops/tensor_buffer_ops.cpp | 20 ++++----- oneflow/user/ops/tf_pool_op.cpp | 2 +- oneflow/user/ops/tf_prelu_op.cpp | 4 +- oneflow/user/ops/transpose_ops.cpp | 2 +- oneflow/user/ops/tril_op.cpp | 8 ++-- oneflow/user/ops/triu_op.cpp | 4 +- oneflow/user/ops/unfold_tensor_op.cpp | 2 +- oneflow/user/ops/unique_with_counts_op.cpp | 16 +++---- oneflow/user/ops/unpack_op.cpp | 4 +- .../ops/unsorted_batch_segment_sum_op.cpp | 4 +- oneflow/user/ops/upsample_op.cpp | 14 +++---- 107 files changed, 350 insertions(+), 330 deletions(-) diff --git a/oneflow/core/framework/infer_util.cpp b/oneflow/core/framework/infer_util.cpp index 27648f0d489..287b89b1675 100644 --- a/oneflow/core/framework/infer_util.cpp +++ b/oneflow/core/framework/infer_util.cpp @@ -68,7 +68,7 @@ Maybe TensorDescInferFnUtil::InOutCorrespond(InferContext* ctx) { for (size_t i = 0; i < ctx->inputs().size(); ++i) { const auto& input_arg = ctx->inputs().at(i); const auto& output_arg = ctx->outputs().at(i); - *ctx->OutputTensorDesc(output_arg.first, output_arg.second) = + *ctx->MutOutputTensorDesc(output_arg.first, output_arg.second) = ctx->InputTensorDesc(input_arg.first, input_arg.second); } return Maybe::Ok(); diff --git a/oneflow/core/framework/infer_util.h b/oneflow/core/framework/infer_util.h index d91114fae54..a66bf9830f9 100644 --- a/oneflow/core/framework/infer_util.h +++ b/oneflow/core/framework/infer_util.h @@ -39,7 +39,8 @@ class InferContext { virtual ~InferContext() = default; virtual const TensorDesc& InputTensorDesc(const std::string&, int32_t) const = 0; - virtual TensorDesc* OutputTensorDesc(const std::string&, int32_t) = 0; + virtual const TensorDesc& 
OutputTensorDesc(const std::string&, int32_t) const = 0; + virtual TensorDesc* MutOutputTensorDesc(const std::string&, int32_t) = 0; virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const Shape& InputShape(const std::string&, int32_t) const = 0; diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index f4e5b3ee871..a0a6bc774ca 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -193,8 +193,11 @@ class UserOpExprInferContext : public user_op::InferContext { int32_t index) const override { return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, index); } - - user_op::TensorDesc* OutputTensorDesc(const std::string& name, int32_t index) override { + const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, + int32_t index) const override { + return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, index); + } + user_op::TensorDesc* MutOutputTensorDesc(const std::string& name, int32_t index) override { return TensorDesc4ArgNameAndIndex(name, index); } diff --git a/oneflow/core/framework/user_op_registry.cpp b/oneflow/core/framework/user_op_registry.cpp index 886b3084cee..c8fc8d0a436 100644 --- a/oneflow/core/framework/user_op_registry.cpp +++ b/oneflow/core/framework/user_op_registry.cpp @@ -228,7 +228,7 @@ Maybe OpRegistry::Finish() { == in_physical.shape()); } for (const auto& pair : ctx->outputs()) { - TensorDesc* desc = ctx->OutputTensorDesc(pair.first, pair.second); + TensorDesc* desc = ctx->MutOutputTensorDesc(pair.first, pair.second); *desc = *ctx->LogicalTensorDesc4ArgNameAndIndex(pair.first, pair.second); const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(pair.first, pair.second); *desc->mut_shape() = *JUST( diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 06432f6ba26..2857c21e112 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ 
b/oneflow/core/kernel/user_kernel.cpp @@ -252,7 +252,12 @@ class UserKernelOpInferContext : public user_op::InferContext { return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, index); } - user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { + const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, + int32_t index) const override { + return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, + index); + } + user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index bb94bfe86e6..2f51db099e2 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -149,7 +149,11 @@ class UserOpInferContext final : public user_op::InferContext { int32_t index) const override { return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, index); } - user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { + const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, + int32_t index) const override { + return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, index); + } + user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { @@ -616,11 +620,11 @@ Maybe UserOp::InferLogicalOutBlobDescs( JUST(val_->logical_tensor_desc_infer_fn(&infer_ctx)); for (const auto& pair : infer_ctx.outputs()) { BlobDesc* out_blob_desc = BlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second)); - user_op::TensorDesc* tensor_desc = infer_ctx.OutputTensorDesc(pair.first, pair.second); - 
out_blob_desc->set_data_type(tensor_desc->data_type()); - out_blob_desc->mut_shape() = tensor_desc->shape(); - out_blob_desc->mut_stride() = tensor_desc->stride(); - out_blob_desc->set_is_dynamic(tensor_desc->is_dynamic()); + const user_op::TensorDesc& tensor_desc = infer_ctx.OutputTensorDesc(pair.first, pair.second); + out_blob_desc->set_data_type(tensor_desc.data_type()); + out_blob_desc->mut_shape() = tensor_desc.shape(); + out_blob_desc->mut_stride() = tensor_desc.stride(); + out_blob_desc->set_is_dynamic(tensor_desc.is_dynamic()); } return Maybe::Ok(); } diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp index e18f0dac968..5513ab66dad 100644 --- a/oneflow/user/kernels/conv_cudnn_kernels.cpp +++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp @@ -221,11 +221,11 @@ class ConvGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphS const auto& in = ctx->InputTensorDesc("in", 0); \ if (in.shape().elem_cnt() == 0) return 0; \ const auto& weight = ctx->InputTensorDesc("weight", 0); \ - const auto* out = ctx->OutputTensorDesc("out", 0); \ + const auto& out = ctx->OutputTensorDesc("out", 0); \ const auto& cudnn_conf = \ Singleton::Get()->resource().cudnn_conf(); \ return InferTmpSizeWithCudnn( \ - &in, &weight, out, *ctx, cudnn_conf.has_cudnn_conv_force_fwd_algo(), \ + &in, &weight, &out, *ctx, cudnn_conf.has_cudnn_conv_force_fwd_algo(), \ cudnn_conf.cudnn_conv_force_fwd_algo()); \ }) @@ -300,12 +300,12 @@ class ConvDataGradGpuKernel final : public user_op::OpKernel, public user_op::Cu .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const auto& dy = ctx->InputTensorDesc("dy", 0); \ const auto& filter = ctx->InputTensorDesc("filter", 0); \ - const auto* dx = ctx->OutputTensorDesc("dx", 0); \ - if (dx->shape().elem_cnt() == 0) return 0; \ + const auto& dx = ctx->OutputTensorDesc("dx", 0); \ + if (dx.shape().elem_cnt() == 0) return 0; \ const auto& cudnn_conf = \ 
Singleton::Get()->resource().cudnn_conf(); \ return InferTmpSizeWithCudnn( \ - dx, &filter, &dy, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), \ + &dx, &filter, &dy, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), \ cudnn_conf.cudnn_conv_force_bwd_data_algo()); \ }) \ .SetInplaceProposalFn([](const user_op::InferContext& ctx, \ @@ -364,21 +364,21 @@ class ConvFilterGradGpuKernel final : public user_op::OpKernel, public user_op:: } }; -#define REGISTER_CONV_FILTER_GRAD_FLOATING_KERNEL(dtype) \ - REGISTER_USER_KERNEL("conv_filter_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ - && (user_op::HobDataType("dy", 0) == GetDataType::value)) \ - .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ - const auto& dy = ctx->InputTensorDesc("dy", 0); \ - const auto& x = ctx->InputTensorDesc("x", 0); \ - if (x.shape().elem_cnt() == 0) return 0; \ - const auto* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); \ - const auto& cudnn_conf = \ - Singleton::Get()->resource().cudnn_conf(); \ - return InferTmpSizeWithCudnn( \ - &x, filter_diff, &dy, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_filter_algo(), \ - cudnn_conf.cudnn_conv_force_bwd_filter_algo()); \ +#define REGISTER_CONV_FILTER_GRAD_FLOATING_KERNEL(dtype) \ + REGISTER_USER_KERNEL("conv_filter_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("dy", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ + const auto& dy = ctx->InputTensorDesc("dy", 0); \ + const auto& x = ctx->InputTensorDesc("x", 0); \ + if (x.shape().elem_cnt() == 0) return 0; \ + const auto& filter_diff = ctx->OutputTensorDesc("filter_diff", 0); \ + const auto& cudnn_conf = \ + Singleton::Get()->resource().cudnn_conf(); \ + return InferTmpSizeWithCudnn( \ + &x, &filter_diff, &dy, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_filter_algo(), \ + 
cudnn_conf.cudnn_conv_force_bwd_filter_algo()); \ }) REGISTER_CONV_FILTER_GRAD_FLOATING_KERNEL(float); diff --git a/oneflow/user/kernels/conv_kernels.cpp b/oneflow/user/kernels/conv_kernels.cpp index e483340d44a..f97fbd2b0ae 100644 --- a/oneflow/user/kernels/conv_kernels.cpp +++ b/oneflow/user/kernels/conv_kernels.cpp @@ -570,7 +570,7 @@ class ConvCpuKernel final : public user_op::OpKernel { && ChannelsLastMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ - const auto& out_shape = ctx->OutputTensorDesc("out", 0)->shape(); \ + const auto& out_shape = ctx->OutputTensorDesc("out", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("weight", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ @@ -748,7 +748,7 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& out_diff_shape = ctx->InputTensorDesc("dy", 0).shape(); \ - const auto& weight_diff_shape = ctx->OutputTensorDesc("filter_diff", 0)->shape(); \ + const auto& weight_diff_shape = ctx->OutputTensorDesc("filter_diff", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ diff --git a/oneflow/user/kernels/deconv_cudnn_kernel.cpp b/oneflow/user/kernels/deconv_cudnn_kernel.cpp index 35c1eec46ce..69938d43ed2 100644 --- a/oneflow/user/kernels/deconv_cudnn_kernel.cpp +++ b/oneflow/user/kernels/deconv_cudnn_kernel.cpp @@ -146,11 +146,11 @@ class DeConvGpuKernel final : public user_op::OpKernel { const auto& in = ctx->InputTensorDesc("in", 0); \ if (in.shape().elem_cnt() == 0) return 0; \ const auto& weight = ctx->InputTensorDesc("weight", 0); \ - const auto* out = ctx->OutputTensorDesc("out", 0); \ + const auto& out = ctx->OutputTensorDesc("out", 0); \ const auto& cudnn_conf = \ Singleton::Get()->resource().cudnn_conf(); \ return InferTmpSizeWithCudnn( \ - 
out, &weight, &in, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), \ + &out, &weight, &in, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), \ cudnn_conf.cudnn_conv_force_bwd_data_algo()); \ }) diff --git a/oneflow/user/kernels/fused_gru_cell_kernel.cu b/oneflow/user/kernels/fused_gru_cell_kernel.cu index 3e91268e939..a584282cf7d 100644 --- a/oneflow/user/kernels/fused_gru_cell_kernel.cu +++ b/oneflow/user/kernels/fused_gru_cell_kernel.cu @@ -459,7 +459,7 @@ REGISTER_USER_KERNEL("fused_gru_cell_grad") size_t tmp_bytes = 0; if (ctx->has_output("grad_input_bias", 0) && ctx->has_output("grad_hidden_bias", 0)) { const Shape& in_shape = ctx->InputTensorDesc("grad_hy", 0).shape(); - const Shape& out_shape = ctx->OutputTensorDesc("grad_input_bias", 0)->shape(); + const Shape& out_shape = ctx->OutputTensorDesc("grad_input_bias", 0).shape(); tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * 3 * sizeof(float)) + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float))); } else { diff --git a/oneflow/user/kernels/fused_lstm_cell_kernel.cu b/oneflow/user/kernels/fused_lstm_cell_kernel.cu index 568ab44d482..e532becb2a4 100644 --- a/oneflow/user/kernels/fused_lstm_cell_kernel.cu +++ b/oneflow/user/kernels/fused_lstm_cell_kernel.cu @@ -492,7 +492,7 @@ REGISTER_USER_KERNEL("fused_lstm_cell_grad") size_t tmp_bytes = 0; if (ctx->has_output("grad_bias", 0)) { const Shape& in_shape = ctx->InputTensorDesc("workspace", 0).shape(); - const Shape& out_shape = ctx->OutputTensorDesc("grad_bias", 0)->shape(); + const Shape& out_shape = ctx->OutputTensorDesc("grad_bias", 0).shape(); tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float)) + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float))); } else { diff --git a/oneflow/user/kernels/group_conv_kernel.cpp b/oneflow/user/kernels/group_conv_kernel.cpp index f85f221bb87..c3aa8dfab46 100644 --- a/oneflow/user/kernels/group_conv_kernel.cpp +++ b/oneflow/user/kernels/group_conv_kernel.cpp @@ -566,7 +566,7 
@@ class ConvCpuKernel final : public user_op::OpKernel { && ChannelsLastMatmulPrimitiveExists()) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ - const auto& out_shape = ctx->OutputTensorDesc("out", 0)->shape(); \ + const auto& out_shape = ctx->OutputTensorDesc("out", 0).shape(); \ const auto& weight_shape = ctx->InputTensorDesc("weight", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ @@ -781,7 +781,7 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ size_t tmp_buffer_size = 0; \ const auto& out_diff_shape = ctx->InputTensorDesc("dy", 0).shape(); \ - const auto& weight_diff_shape = ctx->OutputTensorDesc("filter_diff", 0)->shape(); \ + const auto& weight_diff_shape = ctx->OutputTensorDesc("filter_diff", 0).shape(); \ \ int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ tmp_buffer_size += \ diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index f5649e42287..95518335ea6 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -252,9 +252,9 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern }; size_t Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) { - const user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - return GetCudaAlignedSize(out_tensor->shape().elem_cnt() - * GetSizeOfDataType(out_tensor->data_type())); + const user_op::TensorDesc& out_tensor = ctx->OutputTensorDesc("out", 0); + return GetCudaAlignedSize(out_tensor.shape().elem_cnt() + * GetSizeOfDataType(out_tensor.data_type())); } template diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp index a6287ef27a7..9fdfccbae4f 100644 --- 
a/oneflow/user/kernels/nccl_logical_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_kernels.cpp @@ -276,9 +276,9 @@ class NcclLogicalAllGatherNoncontinuous final : public user_op::OpKernel { }; size_t InferAllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) { - const user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - return GetCudaAlignedSize(out_tensor->shape().elem_cnt() - * GetSizeOfDataType(out_tensor->data_type())); + const user_op::TensorDesc& out_tensor = ctx->OutputTensorDesc("out", 0); + return GetCudaAlignedSize(out_tensor.shape().elem_cnt() + * GetSizeOfDataType(out_tensor.data_type())); } template @@ -348,9 +348,9 @@ class NcclLogicalReduceScatterNoncontinuous final : public user_op::OpKernel { }; size_t InferReduceScatterNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) { - const user_op::TensorDesc* in_tensor = ctx->OutputTensorDesc("in", 0); - return GetCudaAlignedSize(in_tensor->shape().elem_cnt() - * GetSizeOfDataType(in_tensor->data_type())); + const user_op::TensorDesc& in_tensor = ctx->OutputTensorDesc("in", 0); + return GetCudaAlignedSize(in_tensor.shape().elem_cnt() + * GetSizeOfDataType(in_tensor.data_type())); } template diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index 106617d68b0..061420a8109 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -307,7 +307,7 @@ REGISTER_USER_KERNEL("reduce_sum") && ReduceMatmulNoTransAPrimitiveExists()) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputTensorDesc("input_tensor", 0).shape(); - const Shape& out_shape = ctx->OutputTensorDesc("output_tensor", 0)->shape(); + const Shape& out_shape = ctx->OutputTensorDesc("output_tensor", 0).shape(); const auto& axis = RegularAxis(ctx->Attr>("axis")); bool is_axis_contiguous = false; int64_t outer_size = 0, inner_size = 0, reduce_size = 0; diff --git 
a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index df1bfc110cb..bf4c02714c9 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -231,7 +231,7 @@ REGISTER_USER_KERNEL("reduce_sum_like") && ReduceMatmulNoTransAPrimitiveExists()) .SetInferTmpSizeFn([](user_op::InferContext* ctx) { const Shape& in_shape = ctx->InputTensorDesc("x", 0).shape(); - const Shape& out_shape = ctx->OutputTensorDesc("y", 0)->shape(); + const Shape& out_shape = ctx->OutputTensorDesc("y", 0).shape(); const auto& axis = RegularAxis(ctx->Attr>("axis")); if (axis.empty()) { size_t tmp_bytes = 0; diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index a7c47107e23..e2f8de6809a 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -161,9 +161,12 @@ class UserOpInferContextHelper final { const std::string& arg_name, int32_t index) const { return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)); } - - user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + const user_op::TensorDesc& OutputTensorDesc(eager::CallContext* call_ctx, + const std::string& arg_name, int32_t index) const { + return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)); + } + user_op::TensorDesc* MutOutputTensorDesc(eager::CallContext* call_ctx, + const std::string& arg_name, int32_t index) const { return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, @@ -339,9 +342,13 @@ class UserOpInferContext : public user_op::InferContext { int32_t index) const override { return helper_->InputTensorDesc(call_ctx_, arg_name, index); } - user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { + const 
user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name, + int32_t index) const override { return helper_->OutputTensorDesc(call_ctx_, arg_name, index); } + user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override { + return helper_->MutOutputTensorDesc(call_ctx_, arg_name, index); + } user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } diff --git a/oneflow/user/ops/add_n_op.cpp b/oneflow/user/ops/add_n_op.cpp index c135a845c4e..8b1f6e55b30 100644 --- a/oneflow/user/ops/add_n_op.cpp +++ b/oneflow/user/ops/add_n_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe AddNOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("in", 0); - auto* out = ctx->OutputTensorDesc("out", 0); + auto* out = ctx->MutOutputTensorDesc("out", 0); CHECK_NOTNULL_OR_RETURN(out); // NOLINT(maybe-need-error-msg) for (const auto& pair : ctx->inputs()) { const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); @@ -50,7 +50,7 @@ namespace oneflow { /* static */ Maybe AddNOp::InferDataType(user_op::InferContext* ctx) { const auto& in_0 = ctx->InputTensorDesc("in", 0); - auto* out = ctx->OutputTensorDesc("out", 0); + auto* out = ctx->MutOutputTensorDesc("out", 0); CHECK_NOTNULL_OR_RETURN(out); // NOLINT(maybe-need-error-msg) for (const auto& pair : ctx->inputs()) { const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); diff --git a/oneflow/user/ops/affine_grid_op.cpp b/oneflow/user/ops/affine_grid_op.cpp index 1826c039c63..fa2f83c89ed 100644 --- a/oneflow/user/ops/affine_grid_op.cpp +++ b/oneflow/user/ops/affine_grid_op.cpp @@ -48,7 +48,7 @@ Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, /* static */ Maybe AffineGridOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& theta = 
ctx->InputTensorDesc("theta", 0); - user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0); + user_op::TensorDesc* grid = ctx->MutOutputTensorDesc("grid", 0); const Shape& size = ctx->Attr("size"); // Only support 2D or 3D affine grid with NCHW layout // For 2D grid: theta = { N, 2, 3 }, @@ -85,7 +85,7 @@ Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, /*static*/ Maybe AffineGridOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); - user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0); + user_op::TensorDesc* grid = ctx->MutOutputTensorDesc("grid", 0); const Shape& size = ctx->Attr("size"); // Only support 2D or 3D affine grid with NCHW layout // For 2D grid: theta = { N, 2, 3 }, @@ -153,9 +153,9 @@ Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, const user_op::TensorDesc& dgrid = ctx->InputTensorDesc("dgrid", 0); const Shape& size = ctx->Attr("size"); if (size.NumAxes() == 4) { - *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {dgrid.shape().At(0), 2, 3}; + *(ctx->MutOutputTensorDesc("dtheta", 0)->mut_shape()) = {dgrid.shape().At(0), 2, 3}; } else if (size.NumAxes() == 5) { - *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {dgrid.shape().At(0), 3, 4}; + *(ctx->MutOutputTensorDesc("dtheta", 0)->mut_shape()) = {dgrid.shape().At(0), 3, 4}; } else { CHECK_OR_RETURN(false) << "size MUST be 4D or 5D"; } diff --git a/oneflow/user/ops/amp_white_identity_op.cpp b/oneflow/user/ops/amp_white_identity_op.cpp index 46a90141d8d..449a867f473 100644 --- a/oneflow/user/ops/amp_white_identity_op.cpp +++ b/oneflow/user/ops/amp_white_identity_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe AmpWhiteIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); 
*out->mut_shape() = in.shape(); *out->mut_is_dynamic() = in.is_dynamic(); return Maybe::Ok(); @@ -41,7 +41,7 @@ namespace oneflow { /* static */ Maybe AmpWhiteIdentityOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/arg_where_op.cpp b/oneflow/user/ops/arg_where_op.cpp index 3ce31486a50..048d3c30be1 100644 --- a/oneflow/user/ops/arg_where_op.cpp +++ b/oneflow/user/ops/arg_where_op.cpp @@ -22,10 +22,10 @@ namespace { Maybe InferTensorDesc(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); *output_desc->mut_shape() = Shape({input_shape.elem_cnt(), input_shape.NumAxes()}); output_desc->set_is_dynamic(true); - user_op::TensorDesc* output_size_desc = ctx->OutputTensorDesc("output_size", 0); + user_op::TensorDesc* output_size_desc = ctx->MutOutputTensorDesc("output_size", 0); *output_size_desc->mut_shape() = Shape({1}); return Maybe::Ok(); } @@ -46,9 +46,9 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { /* static */ Maybe ArgwhereOp::InferDataType(user_op::InferContext* ctx) { const DataType dtype = ctx->Attr("dtype"); - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); *output_desc->mut_data_type() = dtype; - user_op::TensorDesc* output_size_desc = ctx->OutputTensorDesc("output_size", 0); + user_op::TensorDesc* output_size_desc = ctx->MutOutputTensorDesc("output_size", 0); *output_size_desc->mut_data_type() = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/as_strided_op.cpp 
b/oneflow/user/ops/as_strided_op.cpp index 45ae191a59d..5f04be87dff 100644 --- a/oneflow/user/ops/as_strided_op.cpp +++ b/oneflow/user/ops/as_strided_op.cpp @@ -24,7 +24,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(size.size(), stride.size()) << "mismatch in length of strides and shape"; DimVector out_vec; out_vec.insert(out_vec.end(), size.cbegin(), size.cend()); - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); *output_desc->mut_shape() = Shape(out_vec); return Maybe::Ok(); } @@ -42,7 +42,7 @@ namespace oneflow { /* static */ auto AsStridedGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const Shape& input_shape = ctx->InputShape("input", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_shape() = input_shape; return Maybe::Ok(); } diff --git a/oneflow/user/ops/avg_pool_op.cpp b/oneflow/user/ops/avg_pool_op.cpp index e59ebdc5609..4a7548797d7 100644 --- a/oneflow/user/ops/avg_pool_op.cpp +++ b/oneflow/user/ops/avg_pool_op.cpp @@ -55,7 +55,7 @@ TensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) { const AvgPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride, ceil_mode, count_include_pad, divisor_override); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc = ctx->InputTensorDesc("x", 0); *y_desc->mut_shape() = params_3d.GetYShape(); @@ -107,7 +107,7 @@ GenBackwardOpConfFn AvgPoolMakeBackwardOpConfFn(const int32_t dim) { } Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { - *ctx->OutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); + *ctx->MutOutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/batch_gather_op.cpp 
b/oneflow/user/ops/batch_gather_op.cpp index f61efbc61b6..2c33db5769a 100644 --- a/oneflow/user/ops/batch_gather_op.cpp +++ b/oneflow/user/ops/batch_gather_op.cpp @@ -28,7 +28,7 @@ namespace oneflow { << Error::RuntimeError() << "The dimension of the indices tensor should be greater than zero, " << "but got " << indices.shape().NumAxes(); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_LE_OR_RETURN(indices.shape().dim_vec().size(), in.shape().dim_vec().size()) << Error::RuntimeError() << "The dimension of the input tensor should be greater than or equal to the dimension of " @@ -97,7 +97,7 @@ namespace oneflow { CHECK_OR_RETURN(IsIndexDataType(indices.data_type())) << Error::TypeError() << "The dtype of the indices tensor must be int32 or int64"; const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/bernoulli_op.cpp b/oneflow/user/ops/bernoulli_op.cpp index 3068b83fd0c..a0aabe496c2 100644 --- a/oneflow/user/ops/bernoulli_op.cpp +++ b/oneflow/user/ops/bernoulli_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe BernoulliOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); *out_tensor->mut_shape() = in_tensor.shape(); return Maybe::Ok(); @@ -38,7 +38,7 @@ namespace oneflow { } /* static */ Maybe BernoulliOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); *out_tensor->mut_data_type() = ctx->Attr("dtype"); return Maybe::Ok(); } diff --git a/oneflow/user/ops/binary_cross_entropy_op.cpp b/oneflow/user/ops/binary_cross_entropy_op.cpp index 1b3d8f60416..f896e4f29ef 100644 --- a/oneflow/user/ops/binary_cross_entropy_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_op.cpp @@ -33,7 +33,7 @@ Maybe InferTensorDescFn_(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape()); } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = input_desc.is_dynamic(); *out_desc->mut_shape() = input_desc.shape(); @@ -67,7 +67,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape()); } - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_is_dynamic() = input_desc.is_dynamic(); *dx_desc->mut_shape() = input_desc.shape(); diff --git a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp index 46eb05e33de..5bb7f863f08 100644 --- a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp @@ 
-36,7 +36,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(pos_weight_desc.shape(), Shape({input_desc.shape().At(input_desc.shape().NumAxes() - 1)})); } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = input_desc.is_dynamic(); *out_desc->mut_shape() = input_desc.shape(); @@ -78,7 +78,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { Shape({input_desc.shape().At(input_desc.shape().NumAxes() - 1)})); } - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_is_dynamic() = input_desc.is_dynamic(); *dx_desc->mut_shape() = input_desc.shape(); diff --git a/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp b/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp index 273219e85ea..73276864ac5 100644 --- a/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp @@ -26,7 +26,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << "Input shape should be equal to Target shape. "; - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = false; *out_desc->mut_shape() = Shape({}); return Maybe::Ok(); @@ -47,7 +47,7 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()) << "Input shape should be equal to Target shape. 
"; - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_is_dynamic() = false; *dx_desc->mut_shape() = input_desc.shape(); return Maybe::Ok(); diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index b61d6722d64..ed310ff2dd5 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -65,7 +65,7 @@ namespace oneflow { /* static */ Maybe CastLikeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dtype_like_tensor_desc = ctx->InputTensorDesc("dtype_like", 0); - user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); *output_tensor_desc->mut_data_type() = dtype_like_tensor_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/cast_op.cpp b/oneflow/user/ops/cast_op.cpp index 0cbcd03ce5f..95b5f0f14cb 100644 --- a/oneflow/user/ops/cast_op.cpp +++ b/oneflow/user/ops/cast_op.cpp @@ -38,7 +38,7 @@ Maybe> MakeCastStream(const Symbol& in_device, /* static */ Maybe CastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); *output_tensor_desc->mut_shape() = input_tensor_desc.shape(); *output_tensor_desc->mut_stride() = input_tensor_desc.stride(); // output's stride should consistent with input's @@ -60,7 +60,7 @@ Maybe> MakeCastStream(const Symbol& in_device, } /* static */ Maybe CastOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); DataType* dtype = output_tensor_desc->mut_data_type(); *dtype = 
ctx->Attr("dtype"); return Maybe::Ok(); diff --git a/oneflow/user/ops/cast_to_static_shape_op.cpp b/oneflow/user/ops/cast_to_static_shape_op.cpp index 2b73703db8e..d37126dacf9 100644 --- a/oneflow/user/ops/cast_to_static_shape_op.cpp +++ b/oneflow/user/ops/cast_to_static_shape_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe CastToStaticShapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); *output_desc->mut_shape() = input_desc.shape(); output_desc->set_is_dynamic(false); return Maybe::Ok(); diff --git a/oneflow/user/ops/coco_reader_op.cpp b/oneflow/user/ops/coco_reader_op.cpp index 9cea9d8c168..55ddb016a59 100644 --- a/oneflow/user/ops/coco_reader_op.cpp +++ b/oneflow/user/ops/coco_reader_op.cpp @@ -20,19 +20,19 @@ namespace oneflow { /* static */ Maybe COCOReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { int64_t batch_size = ctx->Attr("batch_size"); - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc("image", 0); *image_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc("image_id", 0); *image_id_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); + user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc("image_size", 0); *image_size_desc->mut_shape() = Shape({batch_size, 2}); - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc("gt_bbox", 0); *bbox_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* label_desc = 
ctx->OutputTensorDesc("gt_label", 0); + user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc("gt_label", 0); *label_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc("gt_segm", 0); *segm_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc("gt_segm_index", 0); *segm_index_desc->mut_shape() = Shape({batch_size}); return Maybe::Ok(); } @@ -59,19 +59,19 @@ namespace oneflow { device_batch_size /= split_num; } - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc("image", 0); *image_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc("image_id", 0); *image_id_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); + user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc("image_size", 0); *image_size_desc->mut_shape() = Shape({device_batch_size, 2}); - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc("gt_bbox", 0); *bbox_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc("gt_label", 0); *label_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc("gt_segm", 0); *segm_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* segm_index_desc = 
ctx->OutputTensorDesc("gt_segm_index", 0); + user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc("gt_segm_index", 0); *segm_index_desc->mut_shape() = Shape({device_batch_size}); return Maybe::Ok(); } @@ -120,19 +120,19 @@ namespace oneflow { } /* static */ Maybe COCOReaderOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc("image", 0); *image_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc("image_id", 0); *image_id_desc->mut_data_type() = DataType::kInt64; - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); + user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc("image_size", 0); *image_size_desc->mut_data_type() = DataType::kInt32; - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc("gt_bbox", 0); *bbox_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc("gt_label", 0); *label_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc("gt_segm", 0); *segm_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc("gt_segm_index", 0); *segm_index_desc->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); } diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index e9efad838ef..8e0206758a5 100644 --- 
a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe CombinedMarginLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* theta = ctx->OutputTensorDesc("theta", 0); + user_op::TensorDesc* theta = ctx->MutOutputTensorDesc("theta", 0); CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); *ctx->MutOutputShape("y", 0) = ctx->InputShape("x", 0); diff --git a/oneflow/user/ops/concat_op.cpp b/oneflow/user/ops/concat_op.cpp index b8d5e8782e5..f6262a2f3ca 100644 --- a/oneflow/user/ops/concat_op.cpp +++ b/oneflow/user/ops/concat_op.cpp @@ -72,7 +72,7 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& } } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const int64_t max_dim_size = ctx->Attr("max_dim_size"); CHECK_LE_OR_RETURN(out_dim_vec.at(axis), max_dim_size); if (dynamic_dim_size == 0) { @@ -107,7 +107,7 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()); } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = first_in_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/conv_op.cpp b/oneflow/user/ops/conv_op.cpp index 9df06829a42..59b6f60c782 100644 --- a/oneflow/user/ops/conv_op.cpp +++ b/oneflow/user/ops/conv_op.cpp @@ -39,7 +39,7 @@ Maybe InferTensorDesc4Conv(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(NDims, strides.size()); CHECK_EQ_OR_RETURN(NDims, 
padding_before.size()); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); DimVector out_shape(NDims + 2); out_shape.at(0) = in.shape().At(0); const size_t c_dim = data_format == "channels_first" ? 1 : NDims + 1; @@ -378,7 +378,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o filter_diff_dim_vec.emplace_back(x.shape().dim_vec().back() / groups); } - user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); + user_op::TensorDesc* filter_diff = ctx->MutOutputTensorDesc("filter_diff", 0); *filter_diff->mut_shape() = Shape(filter_diff_dim_vec); filter_diff->set_is_dynamic(false); @@ -407,14 +407,14 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type()); - user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); + user_op::TensorDesc* filter_diff = ctx->MutOutputTensorDesc("filter_diff", 0); *filter_diff->mut_data_type() = x.data_type(); return Maybe::Ok(); } /* static */ Maybe ConvBiasGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); + user_op::TensorDesc* bias_diff = ctx->MutOutputTensorDesc("bias_diff", 0); int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); std::string data_format = ctx->Attr("data_format"); @@ -456,7 +456,7 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o /* static */ Maybe ConvBiasGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); + user_op::TensorDesc* bias_diff = 
ctx->MutOutputTensorDesc("bias_diff", 0); *bias_diff->mut_data_type() = dy.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/count_not_finite_op.cpp b/oneflow/user/ops/count_not_finite_op.cpp index 20e752a0b2c..8b8dbfc94de 100644 --- a/oneflow/user/ops/count_not_finite_op.cpp +++ b/oneflow/user/ops/count_not_finite_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe CountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_shape() = Shape({1}); return Maybe::Ok(); } @@ -37,13 +37,13 @@ namespace oneflow { } /* static */ Maybe CountNotFiniteOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_data_type() = DataType::kInt64; return Maybe::Ok(); } /* static */ Maybe MultiCountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_shape() = Shape({1}); return Maybe::Ok(); } @@ -70,7 +70,7 @@ namespace oneflow { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); CHECK_EQ_OR_RETURN(x_desc.data_type(), first_x_desc.data_type()); } - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_data_type() = DataType::kInt64; return Maybe::Ok(); } diff --git a/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp b/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp index 0114b96336a..b7f10237d79 100644 --- a/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp +++ b/oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp @@ -38,8 +38,8 @@ Maybe 
InferDataType4MatmulBackward(user_op::InferContext* ctx) { const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); CHECK_EQ_OR_RETURN(weight_desc.data_type(), dy_desc.data_type()); - user_op::TensorDesc* d_grad_desc = ctx->OutputTensorDesc("d_grad", 0); - user_op::TensorDesc* d_bias_desc = ctx->OutputTensorDesc("d_bias", 0); + user_op::TensorDesc* d_grad_desc = ctx->MutOutputTensorDesc("d_grad", 0); + user_op::TensorDesc* d_bias_desc = ctx->MutOutputTensorDesc("d_bias", 0); *d_grad_desc->mut_data_type() = dy_desc.data_type(); *d_bias_desc->mut_data_type() = dy_desc.data_type(); diff --git a/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp b/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp index 58e9b5e6912..3af1973d9ed 100644 --- a/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp @@ -47,8 +47,8 @@ Maybe InferDataType4MatmulBiasAddBackward(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(x_desc.data_type(), dy_desc.data_type()) << "x's datatype should be the same as y's datatype"; - user_op::TensorDesc* w_grad_desc = ctx->OutputTensorDesc("w_grad", 0); - user_op::TensorDesc* b_grad_desc = ctx->OutputTensorDesc("b_grad", 0); + user_op::TensorDesc* w_grad_desc = ctx->MutOutputTensorDesc("w_grad", 0); + user_op::TensorDesc* b_grad_desc = ctx->MutOutputTensorDesc("b_grad", 0); *w_grad_desc->mut_data_type() = dy_desc.data_type(); *b_grad_desc->mut_data_type() = dy_desc.data_type(); diff --git a/oneflow/user/ops/cublas_fused_mlp_op.cpp b/oneflow/user/ops/cublas_fused_mlp_op.cpp index 9369a0e303c..4cab160770c 100644 --- a/oneflow/user/ops/cublas_fused_mlp_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_op.cpp @@ -83,16 +83,16 @@ Maybe InferDataType4Matmul(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()); } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = 
ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = first_in_desc.data_type(); for (int32_t i = 0; i < ctx->output_size("hidden"); i++) { - user_op::TensorDesc* hidden_desc = ctx->OutputTensorDesc("hidden", i); + user_op::TensorDesc* hidden_desc = ctx->MutOutputTensorDesc("hidden", i); *hidden_desc->mut_data_type() = first_in_desc.data_type(); } for (int32_t i = 0; i < ctx->output_size("cublas_aux"); i++) { - user_op::TensorDesc* aux_desc = ctx->OutputTensorDesc("cublas_aux", i); + user_op::TensorDesc* aux_desc = ctx->MutOutputTensorDesc("cublas_aux", i); *aux_desc->mut_data_type() = DataType::kInt32; } diff --git a/oneflow/user/ops/deconv_op.cpp b/oneflow/user/ops/deconv_op.cpp index fe943945b2a..cb7ebd4d2f7 100644 --- a/oneflow/user/ops/deconv_op.cpp +++ b/oneflow/user/ops/deconv_op.cpp @@ -42,7 +42,7 @@ Maybe InferTensorDesc4DeConv(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(NDims, strides.size()); CHECK_EQ_OR_RETURN(NDims, output_padding.size()); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); DimVector out_shape(NDims + 2); out_shape.at(0) = in.shape().At(0); const size_t c_dim = data_format == "channels_first" ? 
1 : NDims + 1; diff --git a/oneflow/user/ops/diag_op.cpp b/oneflow/user/ops/diag_op.cpp index 624b29a07c5..fceb4ba538c 100644 --- a/oneflow/user/ops/diag_op.cpp +++ b/oneflow/user/ops/diag_op.cpp @@ -41,7 +41,7 @@ namespace oneflow { CHECK_GE_OR_RETURN(out_dim_vec[0], 0); // NOLINT } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(false); *out_desc->mut_shape() = Shape(out_dim_vec); return Maybe::Ok(); @@ -64,7 +64,7 @@ namespace oneflow { /* static */ Maybe DiagGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); const Shape& in_shape = in.shape(); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/diagonal_op.cpp b/oneflow/user/ops/diagonal_op.cpp index 2511e6717e5..4051c5a07ae 100644 --- a/oneflow/user/ops/diagonal_op.cpp +++ b/oneflow/user/ops/diagonal_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { if (last_dim < 0) { last_dim = 0; } out_dim_vec.push_back(last_dim); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); out_desc->set_is_dynamic(false); *out_desc->mut_shape() = Shape(out_dim_vec); return Maybe::Ok(); @@ -59,7 +59,7 @@ namespace oneflow { /* static */ Maybe DiagonalGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); const Shape& in_shape = in.shape(); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/dim_gather_op.cpp 
b/oneflow/user/ops/dim_gather_op.cpp index 4e9c23b663b..0b387864c44 100644 --- a/oneflow/user/ops/dim_gather_op.cpp +++ b/oneflow/user/ops/dim_gather_op.cpp @@ -37,7 +37,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic()); - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); *out->mut_shape() = index.shape(); return Maybe::Ok(); @@ -87,7 +87,7 @@ namespace oneflow { const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); CHECK_OR_RETURN(IsIndexDataType(index.data_type())); const user_op::TensorDesc& in = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/dim_scatter_ops.cpp b/oneflow/user/ops/dim_scatter_ops.cpp index 42759138456..8f378d059a3 100644 --- a/oneflow/user/ops/dim_scatter_ops.cpp +++ b/oneflow/user/ops/dim_scatter_ops.cpp @@ -73,7 +73,7 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { CHECK_LE_OR_RETURN(index.shape().At(i), src.shape().At(i)); } - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); *out->mut_shape() = input ? 
input->shape() : like->shape(); return Maybe::Ok(); } @@ -96,7 +96,7 @@ Maybe InferScalarTensorDesc(user_op::InferContext* ctx) { CHECK_LE_OR_RETURN(index.shape().At(i), input.shape().At(i)); } - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("output", 0); *out->mut_shape() = input.shape(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp index 7a143bb4ecd..f0135503e0d 100644 --- a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp +++ b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp @@ -47,8 +47,8 @@ Maybe InferTensorDesc_(InferContext* ctx) { CHECK_EQ_OR_RETURN(tensor_x.shape().At(i), tensor_y.shape().At(i)); } - TensorDesc* tensor_dx = ctx->OutputTensorDesc("dx", 0); - TensorDesc* tensor_dy = ctx->OutputTensorDesc("dy", 0); + TensorDesc* tensor_dx = ctx->MutOutputTensorDesc("dx", 0); + TensorDesc* tensor_dy = ctx->MutOutputTensorDesc("dy", 0); if (tensor_dx) { *tensor_dx->mut_shape() = tensor_x.shape(); } @@ -59,8 +59,8 @@ Maybe InferTensorDesc_(InferContext* ctx) { Maybe InferDataType_(InferContext* ctx) { const TensorDesc& tensor_dz = ctx->InputTensorDesc("dz", 0); - TensorDesc* tensor_dx = ctx->OutputTensorDesc("dx", 0); - TensorDesc* tensor_dy = ctx->OutputTensorDesc("dy", 0); + TensorDesc* tensor_dx = ctx->MutOutputTensorDesc("dx", 0); + TensorDesc* tensor_dy = ctx->MutOutputTensorDesc("dy", 0); if (tensor_dx) { *tensor_dx->mut_data_type() = tensor_dz.data_type(); } diff --git a/oneflow/user/ops/embedding_op.cpp b/oneflow/user/ops/embedding_op.cpp index b854ae9cf87..6b33338eb6d 100644 --- a/oneflow/user/ops/embedding_op.cpp +++ b/oneflow/user/ops/embedding_op.cpp @@ -46,7 +46,7 @@ namespace oneflow { indices_shape.dim_vec().cend()); out_dim_vec.push_back(weight_shape.At(1)); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = 
ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = Shape(out_dim_vec); return Maybe::Ok(); } @@ -87,7 +87,7 @@ namespace oneflow { /* static */ Maybe EmbeddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& weight_shape = ctx->InputShape("weight", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_shape() = weight_shape; return Maybe::Ok(); diff --git a/oneflow/user/ops/flatten_op.cpp b/oneflow/user/ops/flatten_op.cpp index c7798d56fb5..9c0f05b2903 100644 --- a/oneflow/user/ops/flatten_op.cpp +++ b/oneflow/user/ops/flatten_op.cpp @@ -22,7 +22,7 @@ namespace oneflow { const int32_t start_dim = ctx->Attr("start_dim"); const int32_t end_dim = ctx->Attr("end_dim"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& in_shape = ExpandDimIf0D(in_tensor_desc.shape()); CHECK_GE_OR_RETURN(start_dim, 0); CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes()); diff --git a/oneflow/user/ops/flip_op.cpp b/oneflow/user/ops/flip_op.cpp index 24af3a3f4b7..7f9238885bc 100644 --- a/oneflow/user/ops/flip_op.cpp +++ b/oneflow/user/ops/flip_op.cpp @@ -24,7 +24,7 @@ namespace oneflow { const std::vector dims = ctx->Attr>("dims"); CHECK_OR_RETURN(dims.size() <= input_dims) << "len of dims must less than len of input tensor"; for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims parameter is illegal."; } - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_shape() = x_desc.shape(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_cast_scale_op.cpp b/oneflow/user/ops/fused_cast_scale_op.cpp index 816a10efb06..4825a354705 100644 --- 
a/oneflow/user/ops/fused_cast_scale_op.cpp +++ b/oneflow/user/ops/fused_cast_scale_op.cpp @@ -23,7 +23,7 @@ Maybe FusedCastScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); CHECK_EQ_OR_RETURN(scale_by_tensor.shape().NumAxes(), 1); CHECK_EQ_OR_RETURN(scale_by_tensor.shape().At(0), 1); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_is_dynamic() = x.is_dynamic(); *y->mut_shape() = x.shape(); return Maybe::Ok(); @@ -35,7 +35,7 @@ Maybe FusedCastScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx Maybe FusedCastScaleOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_data_type() = scale_by_tensor.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp b/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp index ced41d69fd8..6f57ee3c713 100644 --- a/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp +++ b/oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp @@ -84,16 +84,16 @@ Maybe InferDataType4Matmul(user_op::InferContext* ctx) { << "The Input's datatype should be equal. 
"; } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = first_in_desc.data_type(); for (int32_t i = 0; i < ctx->output_size("hidden"); i++) { - user_op::TensorDesc* hidden_desc = ctx->OutputTensorDesc("hidden", i); + user_op::TensorDesc* hidden_desc = ctx->MutOutputTensorDesc("hidden", i); *hidden_desc->mut_data_type() = first_in_desc.data_type(); } for (int32_t i = 0; i < ctx->output_size("cublas_aux"); i++) { - user_op::TensorDesc* aux_desc = ctx->OutputTensorDesc("cublas_aux", i); + user_op::TensorDesc* aux_desc = ctx->MutOutputTensorDesc("cublas_aux", i); *aux_desc->mut_data_type() = DataType::kInt32; } diff --git a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp index f17bbc33b93..b126f7754a1 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp @@ -95,7 +95,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(dy_desc.shape().At(dy_desc.shape().NumAxes() - 1), mask_desc.shape().At(mask_desc.shape().NumAxes() - 1)) << " last dim of y and mask is not equal."; - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_shape() = dy_desc.shape(); *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); return Maybe::Ok(); @@ -112,7 +112,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(dy_desc.data_type(), softmax_y_desc.data_type()) << " dy and softmax_y dtype must equal"; CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << " mask dtype only support bool."; - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_data_type() = dy_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp 
b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp index 5d139382288..ee9d553e509 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp @@ -83,7 +83,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(y_desc.shape().At(y_desc.shape().NumAxes() - 1), mask_desc.shape().At(mask_desc.shape().NumAxes() - 1)) << " last dim of y and mask is not equal."; - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_shape() = dy_desc.shape(); *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); return Maybe::Ok(); @@ -99,7 +99,7 @@ namespace oneflow { const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); CHECK_EQ_OR_RETURN(dy_desc.data_type(), y_desc.data_type()) << " dy and y dtype must equal"; CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << " mask dtype only support bool."; - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_data_type() = dy_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index 0a6e2cc00be..7d6573b9a96 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -63,7 +63,7 @@ namespace oneflow { user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape()); *dx_desc->mut_shape() = dy_desc.shape(); *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); @@ -77,7 +77,7 @@ namespace 
oneflow { -> Maybe { const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); *dx_desc->mut_data_type() = dy_desc.data_type(); return Maybe::Ok(); diff --git a/oneflow/user/ops/gather_op.cpp b/oneflow/user/ops/gather_op.cpp index 34fd62b74d5..224a73eb3c6 100644 --- a/oneflow/user/ops/gather_op.cpp +++ b/oneflow/user/ops/gather_op.cpp @@ -25,7 +25,7 @@ namespace oneflow { const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); // For 0-dim Tensor CHECK_GE_OR_RETURN(indices.shape().NumAxes(), 0); // NOLINT - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); DimVector dim_vec; dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), @@ -83,7 +83,7 @@ namespace oneflow { /*static*/ auto GatherOp::InferDataType(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); diff --git a/oneflow/user/ops/gpt_data_loader_op.cpp b/oneflow/user/ops/gpt_data_loader_op.cpp index b6a1e56f926..c8906a14c71 100644 --- a/oneflow/user/ops/gpt_data_loader_op.cpp +++ b/oneflow/user/ops/gpt_data_loader_op.cpp @@ -22,13 +22,13 @@ namespace oneflow { -> Maybe { int64_t batch_size = ctx->Attr("batch_size"); int64_t sample_len = ctx->Attr("seq_length") + ctx->Attr("label_length"); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + 
user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = Shape({batch_size, sample_len}); return Maybe::Ok(); } /*static*/ auto MegatronGptMmapDataLoaderOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - *ctx->OutputTensorDesc("out", 0)->mut_data_type() = ctx->Attr("dtype"); + *ctx->MutOutputTensorDesc("out", 0)->mut_data_type() = ctx->Attr("dtype"); return Maybe::Ok(); } /*static*/ auto MegatronGptMmapDataLoaderOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/grid_sample_op.cpp b/oneflow/user/ops/grid_sample_op.cpp index 45ff68c888d..78360af1151 100644 --- a/oneflow/user/ops/grid_sample_op.cpp +++ b/oneflow/user/ops/grid_sample_op.cpp @@ -47,7 +47,7 @@ Maybe GridSampleOp::CheckAttr(const user_op::UserOpDefWrapper& def, /*static*/ auto GridSampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& grid = ctx->InputTensorDesc("grid", 0); - user_op::TensorDesc& output = *(ctx->OutputTensorDesc("output", 0)); + user_op::TensorDesc& output = *(ctx->MutOutputTensorDesc("output", 0)); // Only support 4D or 5D input with NCHW layout // For 4D grid: input = { N, C, H_in, W_in }, // grid = { N, H_out, W_out, 2 } @@ -111,8 +111,8 @@ Maybe GridSampleGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, /*static*/ auto GridSampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { - *(ctx->OutputTensorDesc("dinput", 0)->mut_shape()) = ctx->InputTensorDesc("input", 0).shape(); - *(ctx->OutputTensorDesc("dgrid", 0)->mut_shape()) = ctx->InputTensorDesc("grid", 0).shape(); + *(ctx->MutOutputTensorDesc("dinput", 0)->mut_shape()) = ctx->InputTensorDesc("input", 0).shape(); + *(ctx->MutOutputTensorDesc("dgrid", 0)->mut_shape()) = ctx->InputTensorDesc("grid", 0).shape(); return Maybe::Ok(); } /*static*/ auto GridSampleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) diff --git 
a/oneflow/user/ops/image_batch_align_op.cpp b/oneflow/user/ops/image_batch_align_op.cpp index 0563281485b..868663ff9ee 100644 --- a/oneflow/user/ops/image_batch_align_op.cpp +++ b/oneflow/user/ops/image_batch_align_op.cpp @@ -36,7 +36,7 @@ bool PowerOfTwo(T x) { DimVector dim_vec(shape_attr.NumAxes() + 1); dim_vec.at(0) = in_desc.shape().elem_cnt(); FOR_RANGE(int64_t, i, 0, shape_attr.NumAxes()) { dim_vec.at(i + 1) = shape_attr.At(i); } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = Shape(dim_vec); out_desc->set_is_dynamic(dynamic_out); return Maybe::Ok(); @@ -90,7 +90,7 @@ bool PowerOfTwo(T x) { /* static */ Maybe ImageBatchAlignOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = ctx->Attr("data_type"); return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_decode_op.cpp b/oneflow/user/ops/image_decode_op.cpp index cd308ce528e..7cd4c7cb4e8 100644 --- a/oneflow/user/ops/image_decode_op.cpp +++ b/oneflow/user/ops/image_decode_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe ImageDecodeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = in_desc.shape(); return Maybe::Ok(); } @@ -58,7 +58,7 @@ namespace oneflow { /* static */ Maybe ImageDecodeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = 
ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index 8279db56dcb..e9f82dbcc2b 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -31,7 +31,7 @@ namespace oneflow { CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int64_t N = in_tensor.shape().At(0); int64_t H = ctx->Attr("crop_h"); int64_t W = ctx->Attr("crop_w"); @@ -71,7 +71,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); DataType output_dtype = ctx->Attr("output_dtype"); CHECK_EQ_OR_RETURN(output_dtype, DataType::kFloat); // only support float now; for float16 in future @@ -89,7 +89,7 @@ namespace oneflow { CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int64_t N = in_tensor.shape().At(0); int64_t H = ctx->Attr("crop_h"); int64_t W = ctx->Attr("crop_w"); @@ -134,7 +134,7 @@ namespace oneflow { const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + 
user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); DataType output_dtype = ctx->Attr("output_dtype"); CHECK_EQ_OR_RETURN(output_dtype, DataType::kFloat); // only support float now; for float16 in future @@ -143,7 +143,7 @@ namespace oneflow { } /* static */ Maybe CoinFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int64_t batch_size = ctx->Attr("batch_size"); *out_tensor->mut_shape() = Shape({batch_size}); return Maybe::Ok(); @@ -202,14 +202,14 @@ namespace oneflow { } /* static */ Maybe CoinFlipOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); *out_tensor->mut_data_type() = DataType::kInt8; return Maybe::Ok(); } /* static */ Maybe ImageRandomCropOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); *out_tensor->mut_shape() = in_tensor.shape(); *out_tensor->mut_is_dynamic() = in_tensor.is_dynamic(); return Maybe::Ok(); diff --git a/oneflow/user/ops/image_resize_ops.cpp b/oneflow/user/ops/image_resize_ops.cpp index fe6f351ecaf..a899dcae44a 100644 --- a/oneflow/user/ops/image_resize_ops.cpp +++ b/oneflow/user/ops/image_resize_ops.cpp @@ -27,11 +27,11 @@ namespace oneflow { int64_t target_height = ctx->Attr("target_height"); int64_t channels = ctx->Attr("channels"); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); *out_tensor->mut_shape() = Shape({batch_size, target_height, target_width, channels}); out_tensor->set_is_dynamic(in_tensor.is_dynamic()); - 
user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); + user_op::TensorDesc* scale_tensor = ctx->MutOutputTensorDesc("scale", 0); *scale_tensor->mut_shape() = Shape({batch_size, 2}); scale_tensor->set_is_dynamic(in_tensor.is_dynamic()); @@ -77,9 +77,9 @@ namespace oneflow { /* static */ Maybe ImageResizeToFixedOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); *out_tensor->mut_data_type() = ctx->Attr("data_type"); - user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); + user_op::TensorDesc* scale_tensor = ctx->MutOutputTensorDesc("scale", 0); *scale_tensor->mut_data_type() = DataType::kFloat; return Maybe::Ok(); } @@ -88,11 +88,11 @@ namespace oneflow { user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) > 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); *size_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); *scale_desc->mut_shape() = in_desc.shape(); return Maybe::Ok(); } @@ -132,11 +132,11 @@ namespace oneflow { /* static */ Maybe ImageResizeKeepAspectRatioOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - 
user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); *size_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); *scale_desc->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); } diff --git a/oneflow/user/ops/image_target_resize_op.cpp b/oneflow/user/ops/image_target_resize_op.cpp index 49d7db09479..b3212ad05ed 100644 --- a/oneflow/user/ops/image_target_resize_op.cpp +++ b/oneflow/user/ops/image_target_resize_op.cpp @@ -21,11 +21,11 @@ namespace oneflow { /* static */ Maybe ImageTargetResizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); *size_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); *scale_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); return Maybe::Ok(); } @@ -61,11 +61,11 @@ namespace oneflow { /* static */ Maybe ImageTargetResizeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(in_desc.data_type() == 
DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc("size", 0); *size_desc->mut_data_type() = DataType::kInt32; - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc("scale", 0); *scale_desc->mut_data_type() = DataType::kFloat; return Maybe::Ok(); } diff --git a/oneflow/user/ops/in_top_k_op.cpp b/oneflow/user/ops/in_top_k_op.cpp index 3d76b2b9110..0a2dc857cab 100644 --- a/oneflow/user/ops/in_top_k_op.cpp +++ b/oneflow/user/ops/in_top_k_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe InTopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_EQ_OR_RETURN(targets.shape().NumAxes(), 1); CHECK_EQ_OR_RETURN(predictions.shape().NumAxes(), 2); const bool is_dynamic = targets.is_dynamic(); @@ -45,7 +45,7 @@ namespace oneflow { CHECK_OR_RETURN(IsIndexDataType(targets.data_type())); const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); CHECK_EQ_OR_RETURN(predictions.data_type(), DataType::kFloat); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = kBool; return Maybe::Ok(); } diff --git a/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp b/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp index 5b61c8ff2ba..1710d8ce65c 100644 --- a/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp +++ 
b/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp @@ -29,13 +29,13 @@ namespace oneflow { const int64_t n = x_indices.shape().elem_cnt(); const int64_t m = x_values.shape().elem_cnt() / n; - user_op::TensorDesc* y_indices = ctx->OutputTensorDesc("y_indices", 0); - user_op::TensorDesc* y_values = ctx->OutputTensorDesc("y_values", 0); + user_op::TensorDesc* y_indices = ctx->MutOutputTensorDesc("y_indices", 0); + user_op::TensorDesc* y_values = ctx->MutOutputTensorDesc("y_values", 0); *y_indices = x_indices; *y_indices->mut_shape() = Shape({n}); *y_values = x_values; *y_values->mut_shape() = Shape({n, m}); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); *num_unique->mut_shape() = Shape({1}); return Maybe::Ok(); } @@ -52,7 +52,7 @@ namespace oneflow { /* static */ Maybe IndexedSlicesReduceSumOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); CHECK_OR_RETURN(IsIndexDataType(x_indices.data_type())); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); *num_unique->mut_data_type() = DataType::kInt64; return Maybe::Ok(); } diff --git a/oneflow/user/ops/kl_div_op.cpp b/oneflow/user/ops/kl_div_op.cpp index 636e1680015..a2e915ada1d 100644 --- a/oneflow/user/ops/kl_div_op.cpp +++ b/oneflow/user/ops/kl_div_op.cpp @@ -26,7 +26,7 @@ Maybe KlInferTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = input_desc.is_dynamic(); *out_desc->mut_shape() = input_desc.shape(); @@ -51,7 +51,7 @@ Maybe 
InferGradTensorDescFn(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape()); CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape()); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_is_dynamic() = input_desc.is_dynamic(); *dx_desc->mut_shape() = input_desc.shape(); diff --git a/oneflow/user/ops/layer_norm_op.cpp b/oneflow/user/ops/layer_norm_op.cpp index 3ae2765b362..d900c0d0d59 100644 --- a/oneflow/user/ops/layer_norm_op.cpp +++ b/oneflow/user/ops/layer_norm_op.cpp @@ -43,9 +43,9 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { /* static */ Maybe LayerNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); - user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); + user_op::TensorDesc* mean = ctx->MutOutputTensorDesc("mean", 0); + user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc("inv_variance", 0); const bool center = ctx->Attr("center"); const bool scale = ctx->Attr("scale"); const int64_t begin_params_axis = @@ -99,7 +99,7 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { /* static */ Maybe LayerNormOp::InferDataType(user_op::InferContext* ctx) { const bool center = ctx->Attr("center"); const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_data_type() = x.data_type(); if (center) { const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); @@ -110,8 +110,8 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { const user_op::TensorDesc& gamma = 
ctx->InputTensorDesc("gamma", 0); CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type()); } - user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); - user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); + user_op::TensorDesc* mean = ctx->MutOutputTensorDesc("mean", 0); + user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc("inv_variance", 0); *mean->mut_data_type() = InferBnParamDataType(x.data_type()); *inv_variance->mut_data_type() = mean->data_type(); return Maybe::Ok(); @@ -122,7 +122,7 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); - user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); CHECK_EQ_OR_RETURN(dy.shape(), x.shape()); const int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); CHECK_GT_OR_RETURN(begin_norm_axis, 0); @@ -167,7 +167,7 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { const DataType& bn_param_data_type = InferBnParamDataType(x.data_type()); CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type); CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type); - user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0); *dx->mut_data_type() = dy.data_type(); if (ctx->has_input("_add_to_output", 0)) { const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); @@ -200,11 +200,11 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { dy.shape().dim_vec().cend()); const Shape param_shape(param_shape_dim_vec); if (has_beta_diff) { - user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); + user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0); 
*beta_diff->mut_shape() = param_shape; } if (has_gamma_diff) { - user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); + user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0); *gamma_diff->mut_shape() = param_shape; } return Maybe::Ok(); @@ -237,11 +237,11 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { const bool has_gamma_diff = has_tensor("gamma_diff"); const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); if (has_beta_diff) { - user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); + user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0); *beta_diff->mut_data_type() = dy.data_type(); } if (has_gamma_diff) { - user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); + user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0); *gamma_diff->mut_data_type() = dy.data_type(); } return Maybe::Ok(); diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp index bf21d92d548..b86b9416c76 100644 --- a/oneflow/user/ops/math_binary_broadcast_ops.cpp +++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp @@ -31,7 +31,7 @@ bool IsZeroDimTensor(const user_op::TensorDesc* tensor) { return tensor->shape() Maybe InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) { const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); - user_op::TensorDesc* tensor_z = ctx->OutputTensorDesc("z", 0); + user_op::TensorDesc* tensor_z = ctx->MutOutputTensorDesc("z", 0); size_t output_num_axes = std::max(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes()); if (IsZeroDimTensor(&tensor_x)) { diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index 0b650ef622d..48fb0cb0d86 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -34,7 +34,7 @@ Maybe 
InferTensorDesc4Matmul(user_op::InferContext* ctx) { for (int i = 0; i < num_axes - 2; ++i) { CHECK_EQ_OR_RETURN(a.shape().At(i), b.shape().At(i)); } } - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *ctx->MutOutputShape("out", 0) = ctx->InputShape("a", 0); *ctx->MutOutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); @@ -286,7 +286,7 @@ void GenBackwardOpConf4Matmul(const std::string& op_type_name, const user_op::Us const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); const int64_t num_a_dims = a.shape().NumAxes(); const int64_t num_b_dims = b.shape().NumAxes(); @@ -475,7 +475,7 @@ void GenBackwardOpConf4Matmul(const std::string& op_type_name, const user_op::Us user_op::InferContext* ctx) { const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); for (int i = 0; i < a.shape().NumAxes() - 1; ++i) { diff --git a/oneflow/user/ops/max_pool_op.cpp b/oneflow/user/ops/max_pool_op.cpp index 53c5573f2a6..8fe4d43d727 100644 --- a/oneflow/user/ops/max_pool_op.cpp +++ b/oneflow/user/ops/max_pool_op.cpp @@ -47,12 +47,12 @@ TensorDescInferFn MaxPoolMakeForwardTensorDescInferFn(const int32_t dim) { const MaxPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride, dilation, return_indices, ceil_mode); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc = ctx->InputTensorDesc("x", 0); *y_desc->mut_shape() = params_3d.GetYShape(); - user_op::TensorDesc* 
indice_desc = ctx->OutputTensorDesc("indice", 0); - *indice_desc = *ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* indice_desc = ctx->MutOutputTensorDesc("indice", 0); + *indice_desc = *ctx->MutOutputTensorDesc("y", 0); *indice_desc->mut_shape() = *y_desc->mut_shape(); DataType* dtype = indice_desc->mut_data_type(); *dtype = kInt64; @@ -111,7 +111,7 @@ GenBackwardOpConfFn MaxPoolMakeBackwardOpConfFn(const int32_t dim) { } Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { - *ctx->OutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); + *ctx->MutOutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); return Maybe::Ok(); } diff --git a/oneflow/user/ops/mutable_cast_once_op.cpp b/oneflow/user/ops/mutable_cast_once_op.cpp index 3c707cb262d..a9ee5719c64 100644 --- a/oneflow/user/ops/mutable_cast_once_op.cpp +++ b/oneflow/user/ops/mutable_cast_once_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe MutableCastOnceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); *output_tensor_desc->mut_shape() = input_tensor_desc.shape(); *output_tensor_desc->mut_is_dynamic() = input_tensor_desc.is_dynamic(); return Maybe::Ok(); @@ -40,7 +40,7 @@ namespace oneflow { } /* static */ Maybe MutableCastOnceOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc("out", 0); DataType* dtype = output_tensor_desc->mut_data_type(); *dtype = ctx->Attr("dtype"); return Maybe::Ok(); diff --git a/oneflow/user/ops/narrow_op.cpp b/oneflow/user/ops/narrow_op.cpp index da0a34f218b..99572be22c3 100644 --- a/oneflow/user/ops/narrow_op.cpp +++ b/oneflow/user/ops/narrow_op.cpp @@ -29,7 +29,7 @@ 
namespace oneflow { CHECK_GE_OR_RETURN(length, 0); // length should be input size if split the full slice dimension if (start == 0 && length > in.shape().At(dim)) { length = in.shape().At(dim); } - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); DimVector dim_vec; dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cbegin() + dim); @@ -72,7 +72,7 @@ namespace oneflow { /* static */ Maybe NarrowOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/nll_op.cpp b/oneflow/user/ops/nll_op.cpp index b7a56773e9f..894bcfc89da 100644 --- a/oneflow/user/ops/nll_op.cpp +++ b/oneflow/user/ops/nll_op.cpp @@ -61,11 +61,11 @@ namespace oneflow { << weight_desc.shape().ToString(); } - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("output", 0); *output_desc->mut_is_dynamic() = is_dynamic; *output_desc->mut_shape() = Shape({N}); - user_op::TensorDesc* out_weight_desc = ctx->OutputTensorDesc("out_weight", 0); + user_op::TensorDesc* out_weight_desc = ctx->MutOutputTensorDesc("out_weight", 0); *out_weight_desc->mut_is_dynamic() = is_dynamic; *out_weight_desc->mut_shape() = Shape({N}); @@ -159,7 +159,7 @@ namespace oneflow { << weight_desc.shape().ToString(); } - user_op::TensorDesc* in_grad_desc = ctx->OutputTensorDesc("in_grad", 0); + user_op::TensorDesc* in_grad_desc = ctx->MutOutputTensorDesc("in_grad", 0); *in_grad_desc->mut_is_dynamic() = is_dynamic; *in_grad_desc->mut_shape() = input_desc.shape(); diff --git a/oneflow/user/ops/normalization_op.cpp b/oneflow/user/ops/normalization_op.cpp index 8d0db30cffd..4799eca4a87 
100644 --- a/oneflow/user/ops/normalization_op.cpp +++ b/oneflow/user/ops/normalization_op.cpp @@ -50,7 +50,7 @@ std::function(const std::string&)> MakeSetParamTensorDescFn(user_op: const Shape& shape) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { - auto* tensor_desc = ctx->OutputTensorDesc(bn, 0); + auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr); *tensor_desc->mut_shape() = shape; } @@ -62,7 +62,7 @@ std::function(const std::string&)> MakeSetParamDataTypeFn(user_op::I DataType data_type) { return [=](const std::string& bn) -> Maybe { if (ctx->has_output(bn, 0)) { - auto* tensor_desc = ctx->OutputTensorDesc(bn, 0); + auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0); CHECK_OR_RETURN(tensor_desc != nullptr); *tensor_desc->mut_data_type() = data_type; } @@ -141,7 +141,7 @@ user_op::TensorDescInferFn MakeFwTensorDescInferFn( CHECK_EQ_OR_RETURN(add_to_output.data_type(), data_type); CHECK_EQ_OR_RETURN(add_to_output.shape(), x_shape); } - *ctx->OutputTensorDesc("y", 0) = x; + *ctx->MutOutputTensorDesc("y", 0) = x; const auto axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0); CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()); @@ -159,7 +159,7 @@ user_op::TensorDescInferFn MakeFwTensorDescInferFn( JUST(SetParamTensorDesc("inv_variance")); if (ctx->has_output("reserve_space", 0)) { CHECK_OR_RETURN(reserve_space_infer_fn); - reserve_space_infer_fn(ctx, &x, ctx->OutputTensorDesc("reserve_space", 0)); + reserve_space_infer_fn(ctx, &x, ctx->MutOutputTensorDesc("reserve_space", 0)); } return Maybe::Ok(); }; @@ -179,7 +179,7 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn( const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); CHECK_EQ_OR_RETURN(add_to_output.data_type(), data_type); } - *ctx->OutputTensorDesc("y", 0) = x; + *ctx->MutOutputTensorDesc("y", 0) = x; const DataType param_data_type = data_type == DataType::kFloat16 ? 
DataType::kFloat : data_type; const auto CheckParamDataType = MakeCheckParamDataTypeFn(ctx, param_data_type); const auto SetParamDataType = MakeSetParamDataTypeFn(ctx, param_data_type); @@ -195,7 +195,7 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn( JUST(SetParamDataType("inv_variance")); if (ctx->has_output("reserve_space", 0)) { CHECK_OR_RETURN(reserve_space_infer_fn); - reserve_space_infer_fn(ctx, &x, ctx->OutputTensorDesc("reserve_space", 0)); + reserve_space_infer_fn(ctx, &x, ctx->MutOutputTensorDesc("reserve_space", 0)); } return Maybe::Ok(); }; @@ -435,8 +435,8 @@ Maybe BwTensorDescInferFn(user_op::InferContext* ctx) { const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(y.shape(), x_shape); } - *ctx->OutputTensorDesc("dx", 0) = x; - if (ctx->has_output("addend_diff", 0)) { *ctx->OutputTensorDesc("addend_diff", 0) = x; } + *ctx->MutOutputTensorDesc("dx", 0) = x; + if (ctx->has_output("addend_diff", 0)) { *ctx->MutOutputTensorDesc("addend_diff", 0) = x; } const Shape param_shape({x_shape.At(ctx->Attr("axis"))}); const auto CheckParamTensorDesc = MakeCheckParamTensorDescFn(ctx, param_shape); const auto SetParamTensorDesc = MakeSetParamTensorDescFn(ctx, param_shape); @@ -458,8 +458,8 @@ Maybe BwDataTypeInferFn(user_op::InferContext* ctx) { const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); CHECK_EQ_OR_RETURN(y.data_type(), x_type); } - *ctx->OutputTensorDesc("dx", 0) = x; - if (ctx->has_output("addend_diff", 0)) { *ctx->OutputTensorDesc("addend_diff", 0) = x; } + *ctx->MutOutputTensorDesc("dx", 0) = x; + if (ctx->has_output("addend_diff", 0)) { *ctx->MutOutputTensorDesc("addend_diff", 0) = x; } const DataType param_data_type = x_type == DataType::kFloat16 ? 
DataType::kFloat : x_type; const auto CheckParamDataType = MakeCheckParamDataTypeFn(ctx, param_data_type); const auto SetParamDataType = MakeSetParamDataTypeFn(ctx, param_data_type); diff --git a/oneflow/user/ops/ofrecord_decoder_ops.cpp b/oneflow/user/ops/ofrecord_decoder_ops.cpp index 02ccf542062..f03ce2d09c7 100644 --- a/oneflow/user/ops/ofrecord_decoder_ops.cpp +++ b/oneflow/user/ops/ofrecord_decoder_ops.cpp @@ -21,7 +21,7 @@ namespace oneflow { /* static */ Maybe OfrecordRawDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); Shape conf_shape = ctx->Attr("shape"); DimVector dim_vec(1 + conf_shape.NumAxes()); @@ -50,7 +50,7 @@ namespace oneflow { /* static */ Maybe OfrecordRawDecoderOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); *out_tensor->mut_data_type() = ctx->Attr("data_type"); return Maybe::Ok(); @@ -59,7 +59,7 @@ namespace oneflow { /* static */ Maybe OfrecordBytesDecoderOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_is_dynamic() = in.is_dynamic(); *out->mut_shape() = in.shape(); return Maybe::Ok(); @@ -83,7 +83,7 @@ namespace oneflow { /* static */ Maybe OfrecordBytesDecoderOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = 
ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in.data_type() == DataType::kOFRecord); *out->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); @@ -92,7 +92,7 @@ namespace oneflow { /* static */ Maybe OfrecordImageDecoderOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); *out_tensor->mut_shape() = in_tensor.shape(); return Maybe::Ok(); @@ -117,7 +117,7 @@ namespace oneflow { /* static */ Maybe OfrecordImageDecoderOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); *out_tensor->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); @@ -126,7 +126,7 @@ namespace oneflow { /* static */ Maybe OfrecordImageDecoderRandomCropOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); *out_tensor->mut_shape() = in_tensor.shape(); return Maybe::Ok(); @@ -153,7 +153,7 @@ namespace oneflow { /* static */ Maybe OfrecordImageDecoderRandomCropOp::InferDataType( user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* 
out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); *out_tensor->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); diff --git a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp index 9b683de5a1f..801afd7a295 100644 --- a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp +++ b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp @@ -20,8 +20,8 @@ namespace oneflow { /* static */ Maybe OfrecordImageClassificationReaderOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { - user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); - user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); + user_op::TensorDesc* image_tensor = ctx->MutOutputTensorDesc("image", 0); + user_op::TensorDesc* label_tensor = ctx->MutOutputTensorDesc("label", 0); int32_t batch_size = ctx->Attr("batch_size"); *image_tensor->mut_shape() = Shape({batch_size}); *label_tensor->mut_shape() = Shape({batch_size}); @@ -30,8 +30,8 @@ namespace oneflow { /* static */ Maybe OfrecordImageClassificationReaderOp::InferPhysicalTensorDesc( user_op::InferContext* ctx) { - user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); - user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); + user_op::TensorDesc* image_tensor = ctx->MutOutputTensorDesc("image", 0); + user_op::TensorDesc* label_tensor = ctx->MutOutputTensorDesc("label", 0); int32_t local_batch_size = ctx->Attr("batch_size"); int64_t parallel_num = ctx->parallel_ctx().parallel_num(); diff --git a/oneflow/user/ops/ofrecord_reader_op.cpp b/oneflow/user/ops/ofrecord_reader_op.cpp index a43a08015a7..099c84fd508 100644 --- a/oneflow/user/ops/ofrecord_reader_op.cpp +++ b/oneflow/user/ops/ofrecord_reader_op.cpp @@ -19,13 +19,13 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe OFRecordReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); *out_tensor->mut_shape() = Shape({ctx->Attr("batch_size")}); return Maybe::Ok(); } /* static */ Maybe OFRecordReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int32_t batch_size = ctx->Attr("batch_size"); int64_t parallel_num = ctx->parallel_ctx().parallel_num(); if (parallel_num > 1) { diff --git a/oneflow/user/ops/one_hot_op.cpp b/oneflow/user/ops/one_hot_op.cpp index 0928eeb3cbf..33b73e8957a 100644 --- a/oneflow/user/ops/one_hot_op.cpp +++ b/oneflow/user/ops/one_hot_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { // For 0-dim Tensor CHECK_GE_OR_RETURN(indices_desc.shape().NumAxes(), 0) << "indices dim must be great or equal than 0"; - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = indices_desc.is_dynamic(); DimVector dim_vec = indices_desc.shape().dim_vec(); dim_vec.emplace_back(depth); @@ -62,7 +62,7 @@ namespace oneflow { /* static */ Maybe OneHotOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0); CHECK_OR_RETURN(IsIndexDataType(indices_desc.data_type())); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); DataType dtype = ctx->Attr("dtype"); *out_desc->mut_data_type() = dtype; return Maybe::Ok(); diff --git a/oneflow/user/ops/onerec_decoder_op.cpp b/oneflow/user/ops/onerec_decoder_op.cpp index 8e00a20f345..de97c5e435a 100644 --- a/oneflow/user/ops/onerec_decoder_op.cpp +++ 
b/oneflow/user/ops/onerec_decoder_op.cpp @@ -20,7 +20,7 @@ namespace oneflow { /* static */ Maybe OnerecDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); const Shape& static_shape = ctx->Attr("static_shape"); DimVector dim_vec(1 + static_shape.NumAxes()); @@ -65,7 +65,7 @@ namespace oneflow { /* static */ Maybe OnerecDecoderOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); *out_tensor->mut_data_type() = ctx->Attr("data_type"); return Maybe::Ok(); diff --git a/oneflow/user/ops/onerec_reader_op.cpp b/oneflow/user/ops/onerec_reader_op.cpp index 95b34f8dbf4..76bffcb467e 100644 --- a/oneflow/user/ops/onerec_reader_op.cpp +++ b/oneflow/user/ops/onerec_reader_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { /*static*/ Maybe OneRecReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc("out", 0); int64_t batch_size = ctx->Attr("batch_size"); *out_tensor->mut_shape() = Shape({batch_size}); return Maybe::Ok(); diff --git a/oneflow/user/ops/pack_op.cpp b/oneflow/user/ops/pack_op.cpp index 828192e77e2..b0d1fa55745 100644 --- a/oneflow/user/ops/pack_op.cpp +++ b/oneflow/user/ops/pack_op.cpp @@ -34,7 +34,7 @@ namespace oneflow { const Shape& in_shape = in_desc.shape(); const int32_t pack_num = ctx->Attr("pack_num"); CHECK_GT_OR_RETURN(pack_num, 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); if (in_shape.NumAxes() > 0) { *out_desc->mut_shape() = in_shape; diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp index e7e73f3f05d..ee60da401b6 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -33,9 +33,9 @@ namespace oneflow { const int64_t num_sample = ctx->Attr("num_sample"); const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); - user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); - user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); + user_op::TensorDesc* mapped_label = ctx->MutOutputTensorDesc("mapped_label", 0); + user_op::TensorDesc* sampled_weight = ctx->MutOutputTensorDesc("sampled_weight", 0); + user_op::TensorDesc* sampled_label = ctx->MutOutputTensorDesc("sampled_label", 0); *mapped_label->mut_shape() = label.shape(); *mapped_label->mut_is_dynamic() = 
label.is_dynamic(); *sampled_weight->mut_shape() = weight.shape(); @@ -54,9 +54,9 @@ namespace oneflow { const int64_t num_sample_per_rank = num_sample / parallel_num; const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); - user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); - user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); + user_op::TensorDesc* mapped_label = ctx->MutOutputTensorDesc("mapped_label", 0); + user_op::TensorDesc* sampled_weight = ctx->MutOutputTensorDesc("sampled_weight", 0); + user_op::TensorDesc* sampled_label = ctx->MutOutputTensorDesc("sampled_label", 0); *mapped_label->mut_shape() = label.shape(); *mapped_label->mut_is_dynamic() = label.is_dynamic(); *sampled_weight->mut_shape() = weight.shape(); @@ -93,7 +93,7 @@ namespace oneflow { /*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferLogicalTensorDesc( user_op::InferContext* ctx) { user_op::TensorDesc* boxing_disabled_sampled_weight_diff = - ctx->OutputTensorDesc("boxing_disabled_sampled_weight_diff", 0); + ctx->MutOutputTensorDesc("boxing_disabled_sampled_weight_diff", 0); *boxing_disabled_sampled_weight_diff->mut_shape() = ctx->InputShape("sampled_weight_diff", 0); CHECK_EQ_OR_RETURN(boxing_disabled_sampled_weight_diff->shape().At(0) % ctx->parallel_num(), 0); boxing_disabled_sampled_weight_diff->mut_shape()->Set( @@ -101,7 +101,7 @@ namespace oneflow { *boxing_disabled_sampled_weight_diff->mut_is_dynamic() = ctx->InputIsDynamic("sampled_weight_diff", 0); user_op::TensorDesc* boxing_disabled_sampled_label = - ctx->OutputTensorDesc("boxing_disabled_sampled_label", 0); + ctx->MutOutputTensorDesc("boxing_disabled_sampled_label", 0); *boxing_disabled_sampled_label->mut_shape() = ctx->InputShape("sampled_label", 0); 
CHECK_EQ_OR_RETURN(boxing_disabled_sampled_label->shape().At(0) % ctx->parallel_num(), 0); boxing_disabled_sampled_label->mut_shape()->Set( diff --git a/oneflow/user/ops/reduce_like_ops.cpp b/oneflow/user/ops/reduce_like_ops.cpp index 381c0c52ccc..8dabbef1bbd 100644 --- a/oneflow/user/ops/reduce_like_ops.cpp +++ b/oneflow/user/ops/reduce_like_ops.cpp @@ -80,7 +80,7 @@ namespace oneflow { << " when the input axis list is empty"; } - user_op::TensorDesc* y_tensor = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc("y", 0); *y_tensor->mut_shape() = like_tensor.shape(); *y_tensor->mut_is_dynamic() = like_tensor.is_dynamic(); return Maybe::Ok(); diff --git a/oneflow/user/ops/repeat_interleave_op.cpp b/oneflow/user/ops/repeat_interleave_op.cpp index 22742f9cb2f..e76fdd859f6 100644 --- a/oneflow/user/ops/repeat_interleave_op.cpp +++ b/oneflow/user/ops/repeat_interleave_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /*static*/ Maybe Repeat_InterLeaveOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const int64_t repeat_num = ctx->Attr("repeat_num"); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = Shape({repeat_num}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 77b4fc35c3d..85dbcfabc08 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -33,7 +33,7 @@ namespace oneflow { /*static*/ Maybe ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { Shape shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& in_shape = in_tensor_desc.shape(); Shape* out_shape = out_tensor_desc->mut_shape(); @@ -76,7 +76,7 @@ 
namespace oneflow { /*static*/ Maybe ReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { Shape logical_shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& in_shape = in_tensor_desc.shape(); Shape* out_shape = out_tensor_desc->mut_shape(); diff --git a/oneflow/user/ops/roc_auc_score_op.cpp b/oneflow/user/ops/roc_auc_score_op.cpp index 19c428dae90..eb3161e102e 100644 --- a/oneflow/user/ops/roc_auc_score_op.cpp +++ b/oneflow/user/ops/roc_auc_score_op.cpp @@ -19,7 +19,7 @@ limitations under the License. namespace oneflow { /* static */ Maybe RocAucScoreOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const Shape& pred_shape = ctx->InputTensorDesc("pred", 0).shape(); const Shape& label_shape = ctx->InputTensorDesc("label", 0).shape(); CHECK_EQ_OR_RETURN(pred_shape.elem_cnt(), label_shape.elem_cnt()) diff --git a/oneflow/user/ops/same_padding_op.cpp b/oneflow/user/ops/same_padding_op.cpp index 29cc988765f..61d5c944243 100644 --- a/oneflow/user/ops/same_padding_op.cpp +++ b/oneflow/user/ops/same_padding_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { } /*static*/ Maybe SamePaddingOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_shape() = x_desc.shape(); *y_desc->mut_is_dynamic() = x_desc.is_dynamic(); const std::string& data_format = ctx->Attr("data_format"); diff --git a/oneflow/user/ops/scalar_by_tensor_op.cpp b/oneflow/user/ops/scalar_by_tensor_op.cpp index 1f787e2c96b..691a3c82222 100644 --- 
a/oneflow/user/ops/scalar_by_tensor_op.cpp +++ b/oneflow/user/ops/scalar_by_tensor_op.cpp @@ -25,7 +25,7 @@ Maybe TensorDescInferFn(user_op::InferContext* ctx) { const user_op::TensorDesc& scalar = ctx->InputTensorDesc("scalar", 0); CHECK_EQ_OR_RETURN(scalar.shape().elem_cnt(), 1) << Error::RuntimeError() << "The input scalar tensor is not a scalar"; - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_shape() = x.shape(); *y->mut_is_dynamic() = x.is_dynamic(); return Maybe::Ok(); @@ -36,7 +36,7 @@ Maybe DataTypeInferFn(user_op::InferContext* ctx) { const user_op::TensorDesc& scalar = ctx->InputTensorDesc("scalar", 0); CHECK_EQ_OR_RETURN(x.data_type(), scalar.data_type()) << Error::TypeError() << "Tensors x and scalar have different type"; - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_data_type() = x.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp index 2221d06017a..dc447f9f9f4 100644 --- a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp +++ b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()) << Error::RuntimeError() << "The size of label " << label_desc.shape() << " must match the size of prediction " << prediction_desc.shape(); - user_op::TensorDesc* loss_desc = ctx->OutputTensorDesc("loss", 0); + user_op::TensorDesc* loss_desc = ctx->MutOutputTensorDesc("loss", 0); *loss_desc->mut_shape() = prediction_desc.shape(); *loss_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); return Maybe::Ok(); @@ -79,7 +79,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(loss_diff_desc.shape(), prediction_desc.shape()) << Error::RuntimeError() << "The size of loss_diff " << loss_diff_desc.shape() << " must match the size of prediction " << 
prediction_desc.shape(); - user_op::TensorDesc* prediction_diff = ctx->OutputTensorDesc("prediction_diff", 0); + user_op::TensorDesc* prediction_diff = ctx->MutOutputTensorDesc("prediction_diff", 0); *prediction_diff->mut_shape() = prediction_desc.shape(); *prediction_diff->mut_is_dynamic() = prediction_desc.is_dynamic(); return Maybe::Ok(); diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index d5531ee1692..9513f262f11 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -98,7 +98,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { << "The size of slice tuple must be equal to the size of value tensor at dimension " << i << ", but got " << (stop - start + step - 1) / step << " and " << value_shape.At(i); } - auto* y_desc = ctx->OutputTensorDesc("y", 0); + auto* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_shape() = ref_desc.shape(); *y_desc->mut_is_dynamic() = ref_desc.is_dynamic(); return Maybe::Ok(); @@ -111,7 +111,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0); CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type()) << Error::TypeError() << "Tensors ref and value must have same type"; - auto* y_desc = ctx->OutputTensorDesc("y", 0); + auto* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_data_type() = ref_desc.data_type(); return Maybe::Ok(); } @@ -259,7 +259,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { /*static*/ Maybe SliceGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { Shape logical_shape = ctx->Attr("like_shape"); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); const auto& nd_sbp = 
ctx->NdSbp4ArgNameAndIndex("dx", 0); diff --git a/oneflow/user/ops/smooth_l1_loss_op.cpp b/oneflow/user/ops/smooth_l1_loss_op.cpp index 85859963ae7..538c1f57b2a 100644 --- a/oneflow/user/ops/smooth_l1_loss_op.cpp +++ b/oneflow/user/ops/smooth_l1_loss_op.cpp @@ -40,7 +40,7 @@ namespace oneflow { << Error::RuntimeError() << "beta must be greater than or equal to 0, but found it to be " << ctx->Attr("beta"); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = input_desc.is_dynamic(); *out_desc->mut_shape() = input_desc.shape(); @@ -99,7 +99,7 @@ namespace oneflow { << Error::RuntimeError() << "beta must be greater than or equal to 0, but found it to be " << ctx->Attr("beta"); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_is_dynamic() = input_desc.is_dynamic(); *dx_desc->mut_shape() = input_desc.shape(); diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp index df39a5b737f..3979ce57f85 100644 --- a/oneflow/user/ops/softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp @@ -53,7 +53,7 @@ namespace oneflow { } *ctx->MutOutputShape("prob", 0) = ctx->InputShape("prediction", 0); *ctx->MutOutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); *out_desc->mut_shape() = Shape(out_dim_vector); return Maybe::Ok(); @@ -70,7 +70,7 @@ namespace oneflow { << DataType_Name(label_desc.data_type()) << " and " << DataType_Name(prediction_desc.data_type()); *ctx->MutOutputDType("prob", 0) = ctx->InputDType("prediction", 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + 
user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = prediction_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp index 28f7b8f9d11..9c6c3e03332 100644 --- a/oneflow/user/ops/sparse_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp @@ -48,7 +48,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); JUST(CheckPredictionLabelDesc(&prediction_desc, &label_desc)); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); *out_desc->mut_shape() = label_desc.shape(); return Maybe::Ok(); @@ -73,7 +73,7 @@ Maybe InferDataType(user_op::InferContext* ctx) { CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type())) << Error::TypeError() << "The dtype of label must be integer, but found " << DataType_Name(label_desc.data_type()); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = prediction_desc.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp index a915311bc72..923a2b1217d 100644 --- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp @@ -44,7 +44,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { *ctx->MutOutputIsDynamic("prob", 0) = prediction_desc.is_dynamic(); // 'prob' is just for compute prediction's grad, prob's grad will be ignored *ctx->MutOutputShape("prob", 0) = prediction_desc.shape(); - user_op::TensorDesc* out_desc = 
ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); *out_desc->mut_shape() = label_desc.shape(); return Maybe::Ok(); diff --git a/oneflow/user/ops/split_like_op.cpp b/oneflow/user/ops/split_like_op.cpp index 816bb38fb65..db3fcc329ce 100644 --- a/oneflow/user/ops/split_like_op.cpp +++ b/oneflow/user/ops/split_like_op.cpp @@ -76,7 +76,7 @@ namespace oneflow { << ") should be less than the dimension of like (" << like_num_axes << ")"; FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { const user_op::TensorDesc& like_i_desc = ctx->InputTensorDesc("like", i); - user_op::TensorDesc* out_i_desc = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); CHECK_EQ_OR_RETURN(like_i_desc.shape().NumAxes(), like_num_axes) << Error::RuntimeError() << "The dimension of like_i (" << like_i_desc.shape().NumAxes() << ") must match the dimension of the first like (" << like_num_axes << ")"; @@ -120,7 +120,7 @@ namespace oneflow { /*static*/ Maybe SplitLikeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { - user_op::TensorDesc* out_i_desc = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); *out_i_desc->mut_data_type() = in_desc.data_type(); } return Maybe::Ok(); diff --git a/oneflow/user/ops/sqrt_square_sum_op.cpp b/oneflow/user/ops/sqrt_square_sum_op.cpp index 4766f0628ec..0dcc906498c 100644 --- a/oneflow/user/ops/sqrt_square_sum_op.cpp +++ b/oneflow/user/ops/sqrt_square_sum_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SqrtSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_shape() = 
Shape({}); return Maybe::Ok(); } diff --git a/oneflow/user/ops/square_sum_op.cpp b/oneflow/user/ops/square_sum_op.cpp index 3748c184770..53e938d810e 100644 --- a/oneflow/user/ops/square_sum_op.cpp +++ b/oneflow/user/ops/square_sum_op.cpp @@ -26,7 +26,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe SquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_shape() = Shape({1}); return Maybe::Ok(); } @@ -50,7 +50,7 @@ namespace oneflow { return Maybe::Ok(); } /*static*/ Maybe MultiSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_shape() = Shape({1}); return Maybe::Ok(); } @@ -59,7 +59,7 @@ namespace oneflow { } /*static*/ Maybe MultiSquareSumOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x_0 = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); for (int64_t i = 1; i < ctx->input_size("x"); ++i) { const user_op::TensorDesc& x_i = ctx->InputTensorDesc("x", i); CHECK_EQ_OR_RETURN(x_i.data_type(), x_0.data_type()) diff --git a/oneflow/user/ops/stack_op.cpp b/oneflow/user/ops/stack_op.cpp index 1dd129081bd..4a69a6df1ed 100644 --- a/oneflow/user/ops/stack_op.cpp +++ b/oneflow/user/ops/stack_op.cpp @@ -85,7 +85,7 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& } } } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const int64_t max_dim_size = ctx->Attr("max_dim_size"); CHECK_LE_OR_RETURN(out_dim_vec.at(axis), max_dim_size) << "The out shape at axis " << axis << " should be less equal to " << max_dim_size; @@ -130,7 +130,7 @@ Maybe GenGradOp(const 
user_op::UserOpWrapper& op, const user_op::AddOpFn& CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type()) << "The input's data type should be equal to first input's data type. "; } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_data_type() = first_in_desc.data_type(); return Maybe::Ok(); } @@ -184,7 +184,7 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& << "The axis should be less equal than num axes of `like` tensor. "; FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { const user_op::TensorDesc& like_i_desc = ctx->InputTensorDesc("like", i); - user_op::TensorDesc* out_i_desc = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); CHECK_EQ_OR_RETURN(like_i_desc.shape().NumAxes(), like_num_axes) << "The num axes of `like` tensor at index " << i << " should be equal to first `like` tensor. "; @@ -230,7 +230,7 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& /*static*/ Maybe StackGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { - user_op::TensorDesc* out_i_desc = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc("out", i); *out_i_desc->mut_data_type() = in_desc.data_type(); } return Maybe::Ok(); diff --git a/oneflow/user/ops/tensor_buffer_ops.cpp b/oneflow/user/ops/tensor_buffer_ops.cpp index 80b1c5c99ff..576e7e50ecb 100644 --- a/oneflow/user/ops/tensor_buffer_ops.cpp +++ b/oneflow/user/ops/tensor_buffer_ops.cpp @@ -27,7 +27,7 @@ namespace oneflow { } /*static*/ Maybe TensorBufferToTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + 
user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_is_dynamic(in.is_dynamic()); const auto& instance_shape = ctx->Attr("instance_shape"); DimVector dim_vec; @@ -41,7 +41,7 @@ namespace oneflow { } /*static*/ Maybe TensorBufferToTensorOp::InferDataType(user_op::InferContext* ctx) { const auto data_type = ctx->Attr("dtype"); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(IsPODDataType(data_type)); *out->mut_data_type() = data_type; return Maybe::Ok(); @@ -61,7 +61,7 @@ namespace oneflow { const Shape& in_shape = in.shape(); const auto& instance_dims = ctx->Attr("instance_dims"); CHECK_LT_OR_RETURN(instance_dims, in_shape.NumAxes()); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); out->set_is_dynamic(in.is_dynamic()); DimVector out_dim_vec; out_dim_vec.insert(out_dim_vec.end(), in_shape.dim_vec().cbegin(), @@ -75,7 +75,7 @@ namespace oneflow { /*static*/ Maybe TensorToTensorBufferOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); CHECK_OR_RETURN(IsPODDataType(in.data_type())); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); } @@ -84,7 +84,7 @@ namespace oneflow { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } /*static*/ Maybe GenTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); const Shape& shape = ctx->Attr("shape"); const int64_t num_tensor_buffers = shape.elem_cnt(); const std::vector& shape_list = ctx->Attr>("shape_list"); @@ -99,7 +99,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } 
/*static*/ Maybe GenTensorBufferOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = DataType::kTensorBuffer; return Maybe::Ok(); } @@ -116,7 +116,7 @@ namespace oneflow { const bool dynamic_out = ctx->Attr("dynamic_out"); int64_t num_tensor_buffers = in.shape().elem_cnt(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); *out_i->mut_shape() = out_shape; out_i->set_is_dynamic(dynamic_out); } @@ -133,7 +133,7 @@ namespace oneflow { CHECK_OR_RETURN(IsPODDataType(out_dtype)); int64_t num_tensor_buffers = ctx->outputs().size(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); *out_i->mut_data_type() = out_dtype; } return Maybe::Ok(); @@ -168,7 +168,7 @@ namespace oneflow { const bool dynamic_out = ctx->Attr("dynamic_out"); int64_t num_tensor_buffers = in.shape().elem_cnt(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); *out_i->mut_shape() = out_shapes[i]; out_i->set_is_dynamic(dynamic_out); } @@ -185,7 +185,7 @@ namespace oneflow { int64_t num_tensor_buffers = ctx->outputs().size(); for (int64_t i = 0; i < num_tensor_buffers; ++i) { CHECK_OR_RETURN(IsPODDataType(out_dtypes[i])); - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc("out", i); *out_i->mut_data_type() = out_dtypes[i]; } return Maybe::Ok(); diff --git a/oneflow/user/ops/tf_pool_op.cpp b/oneflow/user/ops/tf_pool_op.cpp index b420141ae12..5904a17bedc 100644 --- a/oneflow/user/ops/tf_pool_op.cpp +++ 
b/oneflow/user/ops/tf_pool_op.cpp @@ -43,7 +43,7 @@ TensorDescInferFn MakeFwTensorDescInferFn(const int32_t dim) { const Params3D params_3d(dim, x_shape, data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); *y_desc->mut_shape() = params_3d.GetYShape(); *y_desc->mut_is_dynamic() = ctx->InputIsDynamic("x", 0); return Maybe::Ok(); diff --git a/oneflow/user/ops/tf_prelu_op.cpp b/oneflow/user/ops/tf_prelu_op.cpp index 543b9940b5d..6a0b981114f 100644 --- a/oneflow/user/ops/tf_prelu_op.cpp +++ b/oneflow/user/ops/tf_prelu_op.cpp @@ -39,7 +39,7 @@ namespace oneflow { } /*static*/ Maybe TfPreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const Shape& alpha_shape = ctx->InputShape("alpha", 0); CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_shape.NumAxes() + 1); FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { @@ -91,7 +91,7 @@ namespace oneflow { /*static*/ Maybe TfPreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); const user_op::TensorDesc& alpha_desc = ctx->InputTensorDesc("alpha", 0); CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_desc.shape().NumAxes() + 1); FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { diff --git a/oneflow/user/ops/transpose_ops.cpp b/oneflow/user/ops/transpose_ops.cpp index 2b483d8f449..23525e5c0b0 100644 --- a/oneflow/user/ops/transpose_ops.cpp +++ b/oneflow/user/ops/transpose_ops.cpp @@ -44,7 
+44,7 @@ void CheckIsPerm(const std::vector& perm) { } /*static*/ Maybe TransposeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("output", 0); + user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc("output", 0); const Shape& in_shape = in_tensor_desc.shape(); Shape* out_shape = out_tensor_desc->mut_shape(); const auto& perm = ctx->Attr>("perm"); diff --git a/oneflow/user/ops/tril_op.cpp b/oneflow/user/ops/tril_op.cpp index 933727beef0..bbac1ce5ee0 100644 --- a/oneflow/user/ops/tril_op.cpp +++ b/oneflow/user/ops/tril_op.cpp @@ -36,7 +36,7 @@ namespace oneflow { } /*static*/ Maybe TrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); *out->mut_shape() = in.shape(); *out->mut_is_dynamic() = in.is_dynamic(); @@ -47,7 +47,7 @@ namespace oneflow { } /*static*/ Maybe TrilOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); } @@ -85,7 +85,7 @@ REGISTER_USER_OP_GRAD("tril").SetGenBackwardOpConfFn([](const user_op::UserOpWra } /*static*/ Maybe FusedScaleTrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); *out->mut_shape() = in.shape(); *out->mut_is_dynamic() = in.is_dynamic(); @@ -96,7 +96,7 @@ 
REGISTER_USER_OP_GRAD("tril").SetGenBackwardOpConfFn([](const user_op::UserOpWra } /*static*/ Maybe FusedScaleTrilOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/triu_op.cpp b/oneflow/user/ops/triu_op.cpp index 00448d7f585..606c9e80d3a 100644 --- a/oneflow/user/ops/triu_op.cpp +++ b/oneflow/user/ops/triu_op.cpp @@ -31,7 +31,7 @@ namespace oneflow { } /*static*/ Maybe TriuOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); *out->mut_shape() = in.shape(); *out->mut_is_dynamic() = in.is_dynamic(); @@ -42,7 +42,7 @@ namespace oneflow { } /*static*/ Maybe TriuOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); *out->mut_data_type() = in.data_type(); return Maybe::Ok(); } diff --git a/oneflow/user/ops/unfold_tensor_op.cpp b/oneflow/user/ops/unfold_tensor_op.cpp index 04b6c6c8423..73fba45964f 100644 --- a/oneflow/user/ops/unfold_tensor_op.cpp +++ b/oneflow/user/ops/unfold_tensor_op.cpp @@ -86,7 +86,7 @@ namespace oneflow { /*static*/ Maybe UnfoldTensorGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); const Shape& in_shape = in.shape(); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc("dx", 0); *dx_desc->mut_shape() = 
Shape(in_shape.dim_vec()); return Maybe::Ok(); } diff --git a/oneflow/user/ops/unique_with_counts_op.cpp b/oneflow/user/ops/unique_with_counts_op.cpp index ea0c120dfa7..e36b87503d6 100644 --- a/oneflow/user/ops/unique_with_counts_op.cpp +++ b/oneflow/user/ops/unique_with_counts_op.cpp @@ -25,19 +25,19 @@ namespace oneflow { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_shape() = x.shape(); *y->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); + user_op::TensorDesc* idx = ctx->MutOutputTensorDesc("idx", 0); *idx->mut_shape() = x.shape(); *idx->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); + user_op::TensorDesc* count = ctx->MutOutputTensorDesc("count", 0); *count->mut_shape() = x.shape(); *count->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); *num_unique->mut_shape() = Shape({1}); return Maybe::Ok(); } @@ -48,15 +48,15 @@ namespace oneflow { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); auto out_idx = ctx->Attr("out_idx"); CHECK_OR_RETURN(IsIndexDataType(out_idx)); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y = ctx->MutOutputTensorDesc("y", 0); *y->mut_data_type() = x.data_type(); - user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); + user_op::TensorDesc* idx = ctx->MutOutputTensorDesc("idx", 0); *idx->mut_data_type() = out_idx; - user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); + user_op::TensorDesc* count = ctx->MutOutputTensorDesc("count", 0); *count->mut_data_type() = out_idx; - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + 
user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc("num_unique", 0); *num_unique->mut_data_type() = out_idx; return Maybe::Ok(); } diff --git a/oneflow/user/ops/unpack_op.cpp b/oneflow/user/ops/unpack_op.cpp index b0b4ee12f04..47dfb04c932 100644 --- a/oneflow/user/ops/unpack_op.cpp +++ b/oneflow/user/ops/unpack_op.cpp @@ -35,7 +35,7 @@ namespace oneflow { CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0); const auto unpack_num = ctx->Attr("unpack_num"); CHECK_EQ_OR_RETURN(in_shape.At(0) % unpack_num, 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); *out_desc->mut_shape() = in_desc.shape(); out_desc->mut_shape()->Set(0, in_shape.At(0) / unpack_num); *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); @@ -45,7 +45,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe UnpackOp::InferDataType(user_op::InferContext* ctx) { - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("out", 0); const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); *out_desc->mut_data_type() = in_desc.data_type(); return Maybe::Ok(); diff --git a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp index fa6b1ac22c3..b9a56f11845 100644 --- a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp @@ -46,7 +46,7 @@ namespace oneflow { CHECK_EQ_OR_RETURN(segment_ids.is_dynamic(), data.is_dynamic()); const int64_t num_segments = ctx->Attr("num_segments"); CHECK_GE_OR_RETURN(num_segments, 1); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); FOR_RANGE(int64_t, i, 0, segment_ids.shape().NumAxes() - 1) { CHECK_EQ_OR_RETURN(segment_ids.shape().At(i), data.shape().At(i)); @@ -64,7 +64,7 @@ namespace oneflow { /*static*/ Maybe 
UnsortedBatchSegmentSumOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + user_op::TensorDesc* out = ctx->MutOutputTensorDesc("out", 0); CHECK_OR_RETURN(IsIndexDataType(segment_ids.data_type())); *out->mut_data_type() = data.data_type(); return Maybe::Ok(); diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index f29735fedf6..216cee4bd78 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -24,7 +24,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleLinear1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const double scale_factor = ctx->Attr("scale_factor"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" @@ -53,7 +53,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const double scale_factor = ctx->Attr("scale_factor"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == 3) @@ -81,7 +81,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const double height_scale = ctx->Attr("height_scale"); const double width_scale = 
ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" @@ -112,7 +112,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleBilinear2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const double height_scale = ctx->Attr("height_scale"); const double width_scale = ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" @@ -143,7 +143,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleBicubic2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const double height_scale = ctx->Attr("height_scale"); const double width_scale = ctx->Attr("width_scale"); CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" @@ -174,7 +174,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const double depth_scale = ctx->Attr("depth_scale"); const double height_scale = ctx->Attr("height_scale"); const double width_scale = ctx->Attr("width_scale"); @@ -207,7 +207,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleTrilinear3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); const double depth_scale = ctx->Attr("depth_scale"); const double height_scale = 
ctx->Attr("height_scale"); const double width_scale = ctx->Attr("width_scale"); From e6ba760ec479b92a35088ad7fd1c830e27c24ca8 Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 17:11:12 +0800 Subject: [PATCH 45/67] replce const DataType& with DataType --- oneflow/core/framework/infer_util.h | 6 +++--- oneflow/core/framework/op_expr.cpp | 6 +++--- oneflow/core/kernel/user_kernel.cpp | 6 +++--- oneflow/core/operator/user_op.cpp | 6 +++--- oneflow/user/kernels/stateful_opkernel.cpp | 18 +++++++++--------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/oneflow/core/framework/infer_util.h b/oneflow/core/framework/infer_util.h index d91114fae54..137287e1764 100644 --- a/oneflow/core/framework/infer_util.h +++ b/oneflow/core/framework/infer_util.h @@ -52,10 +52,10 @@ class InferContext { virtual Stride* MutOutputStride(const std::string&, int32_t) = 0; virtual const Stride& Stride4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual Stride* MutStride4ArgNameAndIndex(const std::string&, int32_t) = 0; - virtual const DataType& InputDType(const std::string&, int32_t) const = 0; - virtual const DataType& OutputDType(const std::string&, int32_t) const = 0; + virtual DataType InputDType(const std::string&, int32_t) const = 0; + virtual DataType OutputDType(const std::string&, int32_t) const = 0; virtual DataType* MutOutputDType(const std::string&, int32_t) = 0; - virtual const DataType& Dtype4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual DataType Dtype4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual DataType* MutDtype4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const std::vector>& inputs() const = 0; virtual const std::vector>& outputs() const = 0; diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index f4e5b3ee871..8b41dee7478 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -276,16 +276,16 @@ class 
UserOpExprInferContext : public user_op::InferContext { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } - const DataType& InputDType(const std::string& arg_name, int32_t index) const override { + DataType InputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } - const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { + DataType OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { return MutDtype4ArgNameAndIndex(arg_name, index); } - const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return const_cast(this) ->TensorDesc4ArgNameAndIndex(arg_name, index) ->data_type(); diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 06432f6ba26..e4f2616598a 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -294,16 +294,16 @@ class UserKernelOpInferContext : public user_op::InferContext { Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } - const DataType& InputDType(const std::string& arg_name, int32_t index) const override { + DataType InputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } - const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { + DataType OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { return 
MutDtype4ArgNameAndIndex(arg_name, index); } - const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return const_cast(this) ->TensorDesc4ArgNameAndIndex(arg_name, index) ->data_type(); diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index bb94bfe86e6..f58fcebdf8c 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -214,16 +214,16 @@ class UserOpInferContext final : public user_op::InferContext { if (it == arg2tensor_desc_.end()) { return nullptr; }; return it->second.mut_stride(); } - const DataType& InputDType(const std::string& arg_name, int32_t index) const override { + DataType InputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } - const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { + DataType OutputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); } DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { return MutDtype4ArgNameAndIndex(arg_name, index); } - const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return DataType::kInvalidDataType; }; return it->second.data_type(); diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index a7c47107e23..f37f8771aec 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -212,20 +212,20 @@ class UserOpInferContextHelper final { int32_t index) const { return 
NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride(); } - const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + DataType InputDType(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return Dtype4ArgNameAndIndex(call_ctx, arg_name, index); } - const DataType& OutputDType(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + DataType OutputDType(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return Dtype4ArgNameAndIndex(call_ctx, arg_name, index); } DataType* MutOutputDType(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return MutDtype4ArgNameAndIndex(call_ctx, arg_name, index); } - const DataType& Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, - int32_t index) const { + DataType Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, + int32_t index) const { return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->data_type(); } DataType* MutDtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, @@ -376,16 +376,16 @@ class UserOpInferContext : public user_op::InferContext { Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return helper_->MutStride4ArgNameAndIndex(call_ctx_, arg_name, index); } - const DataType& InputDType(const std::string& arg_name, int32_t index) const override { + DataType InputDType(const std::string& arg_name, int32_t index) const override { return helper_->InputDType(call_ctx_, arg_name, index); } - const DataType& OutputDType(const std::string& arg_name, int32_t index) const override { + DataType OutputDType(const std::string& arg_name, int32_t index) const override { return helper_->OutputDType(call_ctx_, arg_name, index); } DataType* MutOutputDType(const std::string& arg_name, int32_t index) 
override { return helper_->MutOutputDType(call_ctx_, arg_name, index); } - const DataType& Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index); } DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { From 86c5491af0ae49e9c3314ffbdf57972a143b5091 Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 17:44:51 +0800 Subject: [PATCH 46/67] split const and mut func in LocalTensorMeta --- oneflow/core/common/tensor_meta.h | 48 ++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/oneflow/core/common/tensor_meta.h b/oneflow/core/common/tensor_meta.h index 368bd34fbdf..26d086cf68b 100644 --- a/oneflow/core/common/tensor_meta.h +++ b/oneflow/core/common/tensor_meta.h @@ -60,15 +60,30 @@ class TensorMeta : public user_op::TensorDesc { bool is_dynamic() const override { return is_dynamic_; } bool is_contiguous() const { return IsContiguous(shape(), *stride_); } - void set_shape(const std::shared_ptr& val) { shape_ = val; } - Shape* mut_shape() override { return const_cast(shape_.get()); } - void set_stride(const std::shared_ptr& val) { stride_ = val; } - Stride* mut_stride() override { return const_cast(stride_.get()); } - DataType* mut_dtype() { return &data_type_; } - void set_dtype(DataType data_type) { data_type_ = data_type; } - DataType* mut_data_type() override { return &data_type_; } - bool* mut_is_dynamic() override { return &is_dynamic_; } - void set_is_dynamic(bool val) override { is_dynamic_ = val; } + virtual Shape* mut_shape() override { + PRINT_BUG_PROMPT_AND_ABORT(); + return nullptr; + } + virtual Stride* mut_stride() override { + PRINT_BUG_PROMPT_AND_ABORT(); + return nullptr; + } + virtual DataType* mut_data_type() override { + PRINT_BUG_PROMPT_AND_ABORT(); + return nullptr; + } + 
virtual bool* mut_is_dynamic() override { + PRINT_BUG_PROMPT_AND_ABORT(); + return nullptr; + } + virtual void set_is_dynamic(bool val) override { PRINT_BUG_PROMPT_AND_ABORT(); } + + virtual void set_shape(const std::shared_ptr& val) { PRINT_BUG_PROMPT_AND_ABORT(); } + virtual void set_stride(const std::shared_ptr& val) { + PRINT_BUG_PROMPT_AND_ABORT(); + } + virtual DataType* mut_dtype() { PRINT_BUG_PROMPT_AND_ABORT(); } + virtual void set_dtype(DataType data_type) { PRINT_BUG_PROMPT_AND_ABORT(); } protected: TensorMeta& operator=(const TensorMeta& other) { @@ -79,7 +94,6 @@ class TensorMeta : public user_op::TensorDesc { return *this; } - private: std::shared_ptr shape_; std::shared_ptr stride_; DataType data_type_; @@ -100,9 +114,6 @@ class LocalTensorMeta : public TensorMeta { const Symbol& device() const { return device_; } int64_t storage_offset() const { return storage_offset_; } - Symbol* mut_device() { return &device_; } - void set_storage_offset(int64_t offset) { storage_offset_ = offset; } - bool operator==(const LocalTensorMeta& other) const; size_t CalcHashValue() const; @@ -131,6 +142,17 @@ class MutLocalTensorMeta : public TensorMeta { Symbol* mut_device() { return &device_; } void set_storage_offset(int64_t offset) { storage_offset_ = offset; } + Shape* mut_shape() override { return const_cast(shape_.get()); } + Stride* mut_stride() override { return const_cast(stride_.get()); } + DataType* mut_data_type() override { return &data_type_; } + bool* mut_is_dynamic() override { return &is_dynamic_; } + void set_is_dynamic(bool val) override { is_dynamic_ = val; } + + void set_shape(const std::shared_ptr& val) override { shape_ = val; } + void set_stride(const std::shared_ptr& val) override { stride_ = val; } + DataType* mut_dtype() override { return &data_type_; } + void set_dtype(DataType data_type) override { data_type_ = data_type; } + bool operator==(const MutLocalTensorMeta& other) const; size_t CalcHashValue() const; From 
52046b6e037c6c344db668f868b2e3a03afc2bc6 Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 18:38:24 +0800 Subject: [PATCH 47/67] replace const DataType& with DataType ret --- oneflow/user/kernels/arg_where_kernel.cpp | 2 +- oneflow/user/kernels/broadcast_div_grad_kernel.cpp | 2 +- oneflow/user/kernels/broadcast_pow_grad_kernel.cpp | 4 ++-- oneflow/user/kernels/broadcast_pow_grad_kernel.cu | 2 +- ...elf_attention_query_mul_key_and_value_kernel.cu | 2 +- oneflow/user/ops/cast_op.cpp | 2 +- oneflow/user/ops/categorical_ordinal_encode_op.cpp | 2 +- .../user/ops/fused_dot_feature_interaction_op.cpp | 4 ++-- oneflow/user/ops/fused_gru_cell_op.cpp | 4 ++-- oneflow/user/ops/fused_lstm_cell_op.cpp | 4 ++-- ..._self_attention_query_mul_key_and_value_ops.cpp | 4 ++-- oneflow/user/ops/layer_norm_op.cpp | 2 +- oneflow/user/ops/masked_fill_op.cpp | 2 +- oneflow/user/ops/matmul_op.cpp | 2 +- oneflow/user/ops/matrix_vector_product_op.cpp | 4 ++-- oneflow/user/ops/multi_reduce_ops.cpp | 2 +- oneflow/user/ops/nll_op.cpp | 6 +++--- oneflow/user/ops/relu_op.cpp | 2 +- oneflow/user/ops/vector_matrix_product_op.cpp | 4 ++-- oneflow/user/ops/where_op.cpp | 14 +++++++------- 20 files changed, 35 insertions(+), 35 deletions(-) diff --git a/oneflow/user/kernels/arg_where_kernel.cpp b/oneflow/user/kernels/arg_where_kernel.cpp index 6413530089c..51c2f78a811 100644 --- a/oneflow/user/kernels/arg_where_kernel.cpp +++ b/oneflow/user/kernels/arg_where_kernel.cpp @@ -75,7 +75,7 @@ template size_t InferTempStorageBytesSize(user_op::InferContext* ctx) { const Shape& input_shape = ctx->InputShape("input", 0); if (input_shape.NumAxes() == 0) { return 0; } - const DataType& input_dtype = ctx->InputDType("input", 0); + DataType input_dtype = ctx->InputDType("input", 0); DataType output_dtype = ctx->OutputDType("output", 0); return SwitchUtil::SwitchGetWorkspaceBytesSize( SwitchCase(device_type, input_dtype, output_dtype, input_shape.NumAxes()), diff --git 
a/oneflow/user/kernels/broadcast_div_grad_kernel.cpp b/oneflow/user/kernels/broadcast_div_grad_kernel.cpp index 7a786212989..d729573821f 100644 --- a/oneflow/user/kernels/broadcast_div_grad_kernel.cpp +++ b/oneflow/user/kernels/broadcast_div_grad_kernel.cpp @@ -65,7 +65,7 @@ class BroadcastDivGradKernel final : public user_op::OpKernel { && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { \ const user_op::TensorDesc& z = ctx->InputTensorDesc("z", 0); \ - const DataType& data_type = z.data_type(); \ + DataType data_type = z.data_type(); \ const int64_t elem_cnt = z.shape().elem_cnt(); \ return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(data_type)); \ }); diff --git a/oneflow/user/kernels/broadcast_pow_grad_kernel.cpp b/oneflow/user/kernels/broadcast_pow_grad_kernel.cpp index c4cf0570935..a1b06f00034 100644 --- a/oneflow/user/kernels/broadcast_pow_grad_kernel.cpp +++ b/oneflow/user/kernels/broadcast_pow_grad_kernel.cpp @@ -100,7 +100,7 @@ class BroadcastPowYGradKernel final : public user_op::OpKernel { && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { \ const user_op::TensorDesc& z = ctx->InputTensorDesc("z", 0); \ - const DataType& data_type = z.data_type(); \ + DataType data_type = z.data_type(); \ const int64_t elem_cnt = z.shape().elem_cnt(); \ return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(data_type)); \ }); @@ -112,7 +112,7 @@ class BroadcastPowYGradKernel final : public user_op::OpKernel { && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { \ const user_op::TensorDesc& z = ctx->InputTensorDesc("z", 0); \ - const DataType& data_type = z.data_type(); \ + DataType data_type = z.data_type(); \ const int64_t elem_cnt = z.shape().elem_cnt(); \ return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(data_type)); \ }); 
diff --git a/oneflow/user/kernels/broadcast_pow_grad_kernel.cu b/oneflow/user/kernels/broadcast_pow_grad_kernel.cu index 1471f2383c4..3bd84c9ba95 100644 --- a/oneflow/user/kernels/broadcast_pow_grad_kernel.cu +++ b/oneflow/user/kernels/broadcast_pow_grad_kernel.cu @@ -77,7 +77,7 @@ class BroadcastPowYGradKernel final : public user_op::OpKernel { && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) { \ const user_op::TensorDesc& z = ctx->InputTensorDesc("z", 0); \ - const DataType& data_type = z.data_type(); \ + DataType data_type = z.data_type(); \ const int64_t elem_cnt = z.shape().elem_cnt(); \ return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(data_type)); \ }); diff --git a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu index b246c689cbe..01d25ccd375 100644 --- a/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu +++ b/oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu @@ -273,7 +273,7 @@ size_t InferTmpBufferSize(user_op::InferContext* ctx) { size_t InferGradTmpBufferSize(user_op::InferContext* ctx) { const Shape& value_shape = ctx->InputShape("value_grad", 0); - const DataType& value_dtype = ctx->InputDType("value_grad", 0); + DataType value_dtype = ctx->InputDType("value_grad", 0); return value_shape.elem_cnt() * GetSizeOfDataType(value_dtype); } diff --git a/oneflow/user/ops/cast_op.cpp b/oneflow/user/ops/cast_op.cpp index 0cbcd03ce5f..d816f3f7a80 100644 --- a/oneflow/user/ops/cast_op.cpp +++ b/oneflow/user/ops/cast_op.cpp @@ -79,7 +79,7 @@ REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWra user_op::AddOpFn AddOp) -> Maybe { if (op.NeedGenGradTensor4OpInput("in", 0)) { user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); - const DataType& dtype = op.TensorDesc4ArgNameAndIndex("in", 
0).data_type(); + DataType dtype = op.TensorDesc4ArgNameAndIndex("in", 0).data_type(); user_op::UserOpConfWrapper cast_grad_op = builder.Op("cast") .Input("in", op.GetGradTensorWithOpOutput("out", 0)) diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index 3aca99d4a32..59deaeb748c 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -68,7 +68,7 @@ namespace oneflow { } /* static */ Maybe CategoricalOrdinalEncodeOp::InferDataType(user_op::InferContext* ctx) { - const DataType& data_type = ctx->InputDType("in", 0); + DataType data_type = ctx->InputDType("in", 0); CHECK_OR_RETURN(IsIndexDataType(data_type)); CHECK_EQ_OR_RETURN(ctx->InputDType("table", 0), data_type); CHECK_EQ_OR_RETURN(ctx->InputDType("size", 0), data_type); diff --git a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp index 72e4c35fb55..656e4d31a3a 100644 --- a/oneflow/user/ops/fused_dot_feature_interaction_op.cpp +++ b/oneflow/user/ops/fused_dot_feature_interaction_op.cpp @@ -87,7 +87,7 @@ namespace oneflow { /* static */ Maybe FusedDotFeatureInteractionOp::InferDataType(user_op::InferContext* ctx) { const int64_t feature_input_size = ctx->input_size("features"); CHECK_GE_OR_RETURN(feature_input_size, 1); - const auto& first_feature_dtype = ctx->InputDType("features", 0); + DataType first_feature_dtype = ctx->InputDType("features", 0); for (int64_t i = 1; i < feature_input_size; ++i) { CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType("features", i)); } @@ -137,7 +137,7 @@ namespace oneflow { /* static */ Maybe FusedDotFeatureInteractionGradOp::InferDataType( user_op::InferContext* ctx) { - const auto& dy_dtype = ctx->InputDType("dy", 0); + DataType dy_dtype = ctx->InputDType("dy", 0); for (int64_t i = 0; i < ctx->output_size("features_grad"); ++i) { *ctx->MutOutputDType("features_grad", i) = 
dy_dtype; } diff --git a/oneflow/user/ops/fused_gru_cell_op.cpp b/oneflow/user/ops/fused_gru_cell_op.cpp index 5317492bedf..62d4ffa3538 100644 --- a/oneflow/user/ops/fused_gru_cell_op.cpp +++ b/oneflow/user/ops/fused_gru_cell_op.cpp @@ -60,7 +60,7 @@ namespace oneflow { } /* static */ Maybe FusedGruCellOp::InferDataType(user_op::InferContext* ctx) { - const oneflow::DataType& in_types = ctx->InputDType("hx", 0); + DataType in_types = ctx->InputDType("hx", 0); *ctx->MutOutputDType("hy", 0) = in_types; *ctx->MutOutputDType("workspace", 0) = in_types; return Maybe::Ok(); @@ -117,7 +117,7 @@ namespace oneflow { } /* static */ Maybe FusedGruCellGradOp ::InferDataType(user_op::InferContext* ctx) { - const oneflow::DataType& in_types = ctx->InputDType("grad_hy", 0); + DataType in_types = ctx->InputDType("grad_hy", 0); *ctx->MutOutputDType("grad_input_gates", 0) = in_types; *ctx->MutOutputDType("grad_hidden_gates", 0) = in_types; if (ctx->has_output("grad_hx", 0)) { *ctx->MutOutputDType("grad_hx", 0) = in_types; } diff --git a/oneflow/user/ops/fused_lstm_cell_op.cpp b/oneflow/user/ops/fused_lstm_cell_op.cpp index 292dfbfd0dc..aa8179ba374 100644 --- a/oneflow/user/ops/fused_lstm_cell_op.cpp +++ b/oneflow/user/ops/fused_lstm_cell_op.cpp @@ -63,7 +63,7 @@ namespace oneflow { } /* static */ Maybe FusedLstmCellOp::InferDataType(user_op::InferContext* ctx) { - const oneflow::DataType& in_types = ctx->InputDType("cx", 0); + DataType in_types = ctx->InputDType("cx", 0); *ctx->MutOutputDType("hy", 0) = in_types; *ctx->MutOutputDType("cy", 0) = in_types; *ctx->MutOutputDType("workspace", 0) = in_types; @@ -119,7 +119,7 @@ namespace oneflow { } /* static */ Maybe FusedLstmCellGradOp::InferDataType(user_op::InferContext* ctx) { - const oneflow::DataType& in_types = ctx->InputDType("grad_hy", 0); + DataType in_types = ctx->InputDType("grad_hy", 0); *ctx->MutOutputDType("grad_gates", 0) = in_types; if (ctx->has_output("grad_cx", 0)) { *ctx->MutOutputDType("grad_cx", 0) = in_types; } if 
(ctx->has_output("grad_bias", 0)) { *ctx->MutOutputDType("grad_bias", 0) = in_types; } diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp index 0fcdb49ae5a..a96d376df63 100644 --- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp +++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp @@ -20,7 +20,7 @@ namespace oneflow { /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferDataType(user_op::InferContext* ctx) -> Maybe { - const DataType& dtype = ctx->InputDType("hidden_states", 0); + DataType dtype = ctx->InputDType("hidden_states", 0); *ctx->MutOutputDType("query_mul_key", 0) = dtype; *ctx->MutOutputDType("value", 0) = dtype; return Maybe::Ok(); @@ -67,7 +67,7 @@ namespace oneflow { /*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferDataType( user_op::InferContext* ctx) -> Maybe { - const DataType& dtype = ctx->InputDType("query_mul_key_grad", 0); + DataType dtype = ctx->InputDType("query_mul_key_grad", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("value_grad", 0), dtype); *ctx->MutOutputDType("hidden_states_grad", 0) = dtype; return Maybe::Ok(); diff --git a/oneflow/user/ops/layer_norm_op.cpp b/oneflow/user/ops/layer_norm_op.cpp index 3ae2765b362..09dd2a871ad 100644 --- a/oneflow/user/ops/layer_norm_op.cpp +++ b/oneflow/user/ops/layer_norm_op.cpp @@ -164,7 +164,7 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type()); const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); - const DataType& bn_param_data_type = InferBnParamDataType(x.data_type()); + DataType bn_param_data_type = InferBnParamDataType(x.data_type()); CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type); CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type); 
user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); diff --git a/oneflow/user/ops/masked_fill_op.cpp b/oneflow/user/ops/masked_fill_op.cpp index 0f982398245..d4a21990d75 100644 --- a/oneflow/user/ops/masked_fill_op.cpp +++ b/oneflow/user/ops/masked_fill_op.cpp @@ -27,7 +27,7 @@ Maybe InferMaskedFillTensorDesc(user_op::InferContext* ctx) { } Maybe InferMaskedFillDataType(user_op::InferContext* ctx) { - const DataType& mask_dtype = ctx->InputDType("mask", 0); + DataType mask_dtype = ctx->InputDType("mask", 0); CHECK_OR_RETURN(IsIntegralDataType(mask_dtype) || IsBoolDataType(mask_dtype)); *ctx->MutOutputDType("out", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index 0b650ef622d..a7018998980 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -64,7 +64,7 @@ Maybe InferTensorDesc4Matmul(user_op::InferContext* ctx) { } Maybe InferDataType4Matmul(user_op::InferContext* ctx) { - const DataType& dtype = ctx->InputDType("a", 0); + DataType dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype); if (ctx->has_input("_add_to_output", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType("_add_to_output", 0), dtype); diff --git a/oneflow/user/ops/matrix_vector_product_op.cpp b/oneflow/user/ops/matrix_vector_product_op.cpp index c2a7e8fd3f6..65c1c5e3be9 100644 --- a/oneflow/user/ops/matrix_vector_product_op.cpp +++ b/oneflow/user/ops/matrix_vector_product_op.cpp @@ -31,7 +31,7 @@ Maybe InferTensorDesc4MatrixVectorProduct(user_op::InferContext* ctx) { } Maybe InferDataType4MatrixVectorProduct(user_op::InferContext* ctx) { - const DataType& dtype = ctx->InputDType("a", 0); + DataType dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype) << "Matrix A datatype should be equal to Vector B. 
"; *ctx->MutOutputDType("out", 0) = dtype; @@ -63,7 +63,7 @@ Maybe InferTensorDesc4MatrixVectorProductGradB(user_op::InferContext* ctx) } Maybe InferDataType4Grad(user_op::InferContext* ctx) { - const DataType& dtype = ctx->InputDType("dy", 0); + DataType dtype = ctx->InputDType("dy", 0); *ctx->MutOutputDType("dx", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/multi_reduce_ops.cpp b/oneflow/user/ops/multi_reduce_ops.cpp index 04702f80de2..205b312db08 100644 --- a/oneflow/user/ops/multi_reduce_ops.cpp +++ b/oneflow/user/ops/multi_reduce_ops.cpp @@ -28,7 +28,7 @@ Maybe InferMultiReduceOpShape(user_op::InferContext* ctx) { } Maybe InferMultiReduceOpDataType(user_op::InferContext* ctx) { - const auto& x_0_dtype = ctx->InputDType("x", 0); + DataType x_0_dtype = ctx->InputDType("x", 0); for (size_t i = 1; i < ctx->input_size("x"); ++i) { CHECK_EQ_OR_RETURN(ctx->InputDType("x", i), x_0_dtype) << ctx->op_name() << ": the " << i << " th input has the different data type with others"; diff --git a/oneflow/user/ops/nll_op.cpp b/oneflow/user/ops/nll_op.cpp index b7a56773e9f..65301d14f25 100644 --- a/oneflow/user/ops/nll_op.cpp +++ b/oneflow/user/ops/nll_op.cpp @@ -22,9 +22,9 @@ namespace oneflow { CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("target", 0))) << ctx->op_name() << ": expected target being integer type"; - auto input_dtype = ctx->InputDType("input", 0); + DataType input_dtype = ctx->InputDType("input", 0); if (ctx->has_input("weight", 0)) { - auto weight_dtype = ctx->InputDType("weight", 0); + DataType weight_dtype = ctx->InputDType("weight", 0); CHECK_EQ_OR_RETURN(weight_dtype, input_dtype) << ctx->op_name() << ": expected weight dtype " << input_dtype << ", but got " << weight_dtype; } @@ -115,7 +115,7 @@ namespace oneflow { CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("target", 0))) << ctx->op_name() << ": expected target being integer type"; - auto input_dtype = ctx->InputDType("input", 0); + DataType input_dtype = 
ctx->InputDType("input", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("out_grad", 0), input_dtype) << ctx->op_name() << ": expected out_grad dtype " << input_dtype << ", got " << ctx->InputDType("out_grad", 0); diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp index ee7ad9fd349..afeecd58b70 100644 --- a/oneflow/user/ops/relu_op.cpp +++ b/oneflow/user/ops/relu_op.cpp @@ -63,7 +63,7 @@ namespace oneflow { return InferLogicalTensorDesc(ctx); } /*static*/ Maybe ReluGradOp::InferDataType(user_op::InferContext* ctx) { - const DataType& data_type = ctx->InputDType("y", 0); + DataType data_type = ctx->InputDType("y", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), data_type) << Error::TypeError() << "Tensors dy and y must have the same type"; *ctx->MutOutputDType("dx", 0) = data_type; diff --git a/oneflow/user/ops/vector_matrix_product_op.cpp b/oneflow/user/ops/vector_matrix_product_op.cpp index 5642da7a8d1..8204e892655 100644 --- a/oneflow/user/ops/vector_matrix_product_op.cpp +++ b/oneflow/user/ops/vector_matrix_product_op.cpp @@ -31,7 +31,7 @@ Maybe InferTensorDesc4VectorMatrixProduct(user_op::InferContext* ctx) { } Maybe InferDataType4VectorMatrixProduct(user_op::InferContext* ctx) { - const DataType& dtype = ctx->InputDType("a", 0); + DataType dtype = ctx->InputDType("a", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("b", 0), dtype) << "Matrix A datatype should be equal to Vector B. 
"; *ctx->MutOutputDType("out", 0) = dtype; @@ -63,7 +63,7 @@ Maybe InferTensorDesc4VectorMatrixProductGradB(user_op::InferContext* ctx) } Maybe InferDataType4Grad(user_op::InferContext* ctx) { - const DataType& dtype = ctx->InputDType("dy", 0); + DataType dtype = ctx->InputDType("dy", 0); *ctx->MutOutputDType("dx", 0) = dtype; return Maybe::Ok(); } diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index c36d285c76e..29d6ea63ce5 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -209,9 +209,9 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie return InferLogicalTensorDesc(ctx); } /*static*/ Maybe WhereOp::InferDataType(user_op::InferContext* ctx) { - const DataType& cond_dtype = ctx->InputDType("condition", 0); + DataType cond_dtype = ctx->InputDType("condition", 0); CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype)); - const DataType& x_dtype = ctx->InputDType("x", 0); + DataType x_dtype = ctx->InputDType("x", 0); CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType("y", 0)); *ctx->MutOutputDType("out", 0) = x_dtype; return Maybe::Ok(); @@ -231,9 +231,9 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie return InferLogicalTensorDesc(ctx); } /*static*/ Maybe WhereScalarXOp::InferDataType(user_op::InferContext* ctx) { - const DataType& cond_dtype = ctx->InputDType("condition", 0); + DataType cond_dtype = ctx->InputDType("condition", 0); CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype)); - const DataType& y_dtype = ctx->InputDType("y", 0); + DataType y_dtype = ctx->InputDType("y", 0); if (ctx->Attr("has_int_operand")) { CHECK_EQ_OR_RETURN(y_dtype, GetDataType::value) << "expected scalar type " << GetDataType::value << "but found " << y_dtype; @@ -262,9 +262,9 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie return InferLogicalTensorDesc(ctx); } /*static*/ Maybe 
WhereScalarYOp::InferDataType(user_op::InferContext* ctx) { - const DataType& cond_dtype = ctx->InputDType("condition", 0); + DataType cond_dtype = ctx->InputDType("condition", 0); CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype)); - const DataType& x_dtype = ctx->InputDType("x", 0); + DataType x_dtype = ctx->InputDType("x", 0); if (ctx->Attr("has_int_operand")) { CHECK_EQ_OR_RETURN(x_dtype, GetDataType::value) << "expected scalar type " << GetDataType::value << "but found " << x_dtype; @@ -293,7 +293,7 @@ Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifie return InferLogicalTensorDesc(ctx); } /*static*/ Maybe WhereScalarXyOp::InferDataType(user_op::InferContext* ctx) { - const DataType& cond_dtype = ctx->InputDType("condition", 0); + DataType cond_dtype = ctx->InputDType("condition", 0); CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype)); if (ctx->Attr("has_x_bool_operand") && ctx->Attr("has_y_bool_operand")) { *ctx->MutOutputDType("out", 0) = GetDataType::value; From cca9465de3b232b5e1d491e7808e6eda3a600d65 Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 19:44:33 +0800 Subject: [PATCH 48/67] split TensorDesc4ArgNameAndIndex and MutTensorDesc4ArgNameAndIndex --- oneflow/core/framework/op_expr.cpp | 62 ++++++++------ oneflow/core/kernel/user_kernel.cpp | 40 ++++----- oneflow/core/operator/user_op.cpp | 15 +++- oneflow/user/kernels/stateful_opkernel.cpp | 97 ++++++++++++++++------ 4 files changed, 137 insertions(+), 77 deletions(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 8b41dee7478..cf07de52b32 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -191,14 +191,32 @@ class UserOpExprInferContext : public user_op::InferContext { const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { - return 
*const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, index); + return TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* OutputTensorDesc(const std::string& name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(name, index); + return MutTensorDesc4ArgNameAndIndex(name, index); } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& name, + int32_t index) const { + { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + if (tuple_index >= 0) { return *tensor_meta4output_index_(tuple_index); } + } + { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + if (tuple_index >= 0) { + return *const_cast(tensor_meta4input_index_(tuple_index)); + } + } + PRINT_BUG_PROMPT_AND_ABORT(); + return *(user_op::TensorDesc*)nullptr; + } + + user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); @@ -236,13 +254,11 @@ class UserOpExprInferContext : public user_op::InferContext { } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->shape(); + return TensorDesc4ArgNameAndIndex(arg_name, index).shape(); } Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); } const Stride& InputStride(const std::string& name, int32_t index) const override { @@ -267,13 +283,11 @@ class UserOpExprInferContext : public 
user_op::InferContext { } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->stride(); + return TensorDesc4ArgNameAndIndex(arg_name, index).stride(); } Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); } DataType InputDType(const std::string& arg_name, int32_t index) const override { @@ -286,12 +300,10 @@ class UserOpExprInferContext : public user_op::InferContext { return MutDtype4ArgNameAndIndex(arg_name, index); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->data_type(); + return TensorDesc4ArgNameAndIndex(arg_name, index).data_type(); } DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); @@ -303,12 +315,10 @@ class UserOpExprInferContext : public user_op::InferContext { return MutIsDynamic4ArgNameAndIndex(arg_name, index); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->is_dynamic(); + return TensorDesc4ArgNameAndIndex(arg_name, index).is_dynamic(); } bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); } const std::string& 
input(const std::string& arg_name, int32_t index) const override { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); @@ -414,18 +424,16 @@ class UserOpExprLogicalInferContext final : public UserOpExprInferContext { const ParallelDesc& parallel_desc() const override { return *parallel_desc_; } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name, int32_t index) const override { - auto* tensor_meta = dynamic_cast( - const_cast(this)->TensorDesc4ArgNameAndIndex(name, index)); - CHECK_NOTNULL(tensor_meta); - Symbol nd_sbp = tensor_meta->nd_sbp(); + const GlobalTensorMeta& tensor_meta = + dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); + Symbol nd_sbp = tensor_meta.nd_sbp(); CHECK_EQ(nd_sbp->sbp_parallel_size(), 1); return nd_sbp->sbp_parallel(0); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override { - auto* tensor_meta = dynamic_cast( - const_cast(this)->TensorDesc4ArgNameAndIndex(name, index)); - CHECK_NOTNULL(tensor_meta); - return *tensor_meta->nd_sbp(); + const GlobalTensorMeta& tensor_meta = + dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); + return *tensor_meta.nd_sbp(); } int64_t parallel_num() const override { return parallel_desc_->parallel_num(); } diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index e4f2616598a..9668ef2b0f1 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -249,13 +249,21 @@ class UserKernelOpInferContext : public user_op::InferContext { const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { - return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, - index); + return TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index); + return MutTensorDesc4ArgNameAndIndex(arg_name, index); + } + const 
user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, + int32_t index) const { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { + PRINT_BUG_PROMPT_AND_ABORT(); + return *(user_op::TensorDesc*)nullptr; + } + return *it->second; } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { + user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; } return it->second.get(); @@ -270,12 +278,10 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutShape4ArgNameAndIndex(arg_name, index); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->shape(); + return TensorDesc4ArgNameAndIndex(arg_name, index).shape(); } Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); } const Stride& InputStride(const std::string& arg_name, int32_t index) const override { return Stride4ArgNameAndIndex(arg_name, index); @@ -287,12 +293,10 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutStride4ArgNameAndIndex(arg_name, index); } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->stride(); + return TensorDesc4ArgNameAndIndex(arg_name, index).stride(); } Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); + return MutTensorDesc4ArgNameAndIndex(arg_name, 
index)->mut_stride(); } DataType InputDType(const std::string& arg_name, int32_t index) const override { return Dtype4ArgNameAndIndex(arg_name, index); @@ -304,12 +308,10 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutDtype4ArgNameAndIndex(arg_name, index); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->data_type(); + return TensorDesc4ArgNameAndIndex(arg_name, index).data_type(); } DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); } bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { return IsDynamic4ArgNameAndIndex(arg_name, index); @@ -321,12 +323,10 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutIsDynamic4ArgNameAndIndex(arg_name, index); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return const_cast(this) - ->TensorDesc4ArgNameAndIndex(arg_name, index) - ->is_dynamic(); + return TensorDesc4ArgNameAndIndex(arg_name, index).is_dynamic(); } bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); } const ArgVec& inputs() const override { return inputs_; } diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index f58fcebdf8c..4ed1a75a096 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -147,12 +147,21 @@ class UserOpInferContext final : public user_op::InferContext { const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { - 
return *const_cast(this)->TensorDesc4ArgNameAndIndex(arg_name, index); + return TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { - return TensorDesc4ArgNameAndIndex(arg_name, index); + return MutTensorDesc4ArgNameAndIndex(arg_name, index); + } + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, + int32_t index) const { + auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2tensor_desc_.end()) { + PRINT_BUG_PROMPT_AND_ABORT(); + return *(user_op::TensorDesc*)nullptr; + }; + return it->second; } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { + user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return &(it->second); diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index f37f8771aec..853b4c92c71 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -51,9 +51,20 @@ class ZeroCopyBaseContextHelper { index); \ if (i >= 0) { return (outputs).at(i) post_action; } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, - const std::string& arg_name, - const int32_t index) const { + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + const int32_t index) const { + int32_t inedx = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), + arg_name, index); + if (inedx >= 0) { return *call_ctx->inputs().at(inedx); } + inedx = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), + arg_name, index); + if (inedx >= 0) { return *call_ctx->outputs().at(inedx); } + return 
*(user_op::TensorDesc*)nullptr; + } + user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + const int32_t index) const { RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get()); return nullptr; } @@ -159,18 +170,23 @@ class UserOpInferContextHelper final { const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)); + return *CHECK_NOTNULL(MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)); } user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + return MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, - const std::string& arg_name, - int32_t index) const { + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + int32_t index) const { return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } + user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + int32_t index) const { + return zero_copy_base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + } const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { @@ -186,11 +202,11 @@ class UserOpInferContextHelper final { } const Shape& Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->shape(); + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).shape(); } Shape* MutShape4ArgNameAndIndex(eager::CallContext* call_ctx, const 
std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape(); + return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape(); } const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { @@ -206,11 +222,11 @@ class UserOpInferContextHelper final { } const Stride& Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->stride(); + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).stride(); } Stride* MutStride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride(); + return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride(); } DataType InputDType(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { @@ -226,11 +242,11 @@ class UserOpInferContextHelper final { } DataType Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->data_type(); + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).data_type(); } DataType* MutDtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type(); + return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type(); } bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { @@ -246,11 +262,11 @@ class UserOpInferContextHelper final { } bool IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return 
NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->is_dynamic(); + return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).is_dynamic(); } bool* MutIsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic(); + return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic(); } const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); } @@ -311,10 +327,17 @@ class UserOpInferContextHelper final { } private: - user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, - const std::string& arg_name, - int32_t index) const { - user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + const user_op::TensorDesc& NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + int32_t index) const { + const user_op::TensorDesc& tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + if (!(&tensor_desc)) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; } + return tensor_desc; + } + user_op::TensorDesc* MutNonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + int32_t index) const { + user_op::TensorDesc* tensor_desc = MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; } return tensor_desc; } @@ -342,9 +365,13 @@ class UserOpInferContext : public user_op::InferContext { user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { return helper_->OutputTensorDesc(call_ctx_, arg_name, index); } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, + int32_t 
index) const { return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } + user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { + return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + } const Shape& InputShape(const std::string& arg_name, int32_t index) const override { return helper_->InputShape(call_ctx_, arg_name, index); @@ -465,12 +492,18 @@ class UserKernelComputeContextHelper final { ~UserKernelComputeContextHelper() = default; - const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } + const user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + int32_t index) const { + return base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + } + user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index); @@ -505,7 +538,7 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext { const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { @@ -552,11 +585,16 @@ class UserKernelRegContextHelper final { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return base_ctx_helper_.parallel_ctx(call_ctx); } - const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + 
const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } + const user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + int32_t index) const { + return base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + } const ArgVec& inputs() const { return base_ctx_helper_.inputs(); } const ArgVec& outputs() const { return base_ctx_helper_.outputs(); } @@ -582,7 +620,7 @@ class UserKernelRegContext final : public user_op::KernelRegContext { const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& outputs() const override { return helper_->outputs(); } @@ -616,11 +654,16 @@ class UserKernelInitAndCacheContextHelper final { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return base_ctx_helper_.parallel_ctx(call_ctx); } - const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } + const user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const std::string& arg_name, + int32_t index) const { + return base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* 
call_ctx, const std::string& arg_name, int32_t index) const { @@ -676,7 +719,7 @@ class UserKernelInitAndCacheContext final : public user_op::KernelInitContext, const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { From 9ffb9aee441d4ce9a15ec37589795b509262427c Mon Sep 17 00:00:00 2001 From: clackhan Date: Thu, 21 Jul 2022 19:50:44 +0800 Subject: [PATCH 49/67] refine --- oneflow/user/kernels/stateful_opkernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 853b4c92c71..97fe4b36ca3 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -538,7 +538,7 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext { const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return &helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { @@ -620,7 +620,7 @@ class UserKernelRegContext final : public user_op::KernelRegContext { const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return 
&helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& outputs() const override { return helper_->outputs(); } From 86931104a1cca1397588e30d342a4c4990cee89a Mon Sep 17 00:00:00 2001 From: clackhan Date: Fri, 22 Jul 2022 09:29:08 +0800 Subject: [PATCH 50/67] minor fix --- oneflow/user/kernels/stateful_opkernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 97fe4b36ca3..e12a35b0364 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -170,7 +170,7 @@ class UserOpInferContextHelper final { const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return *CHECK_NOTNULL(MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)); + return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, From 190dba7c68404dc3187ebc409a4a3453372cdb94 Mon Sep 17 00:00:00 2001 From: clackhan Date: Fri, 22 Jul 2022 09:55:50 +0800 Subject: [PATCH 51/67] fix merge error --- oneflow/core/framework/op_expr.cpp | 534 ++++++++++++++--------------- 1 file changed, 267 insertions(+), 267 deletions(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index bb1124df288..0c8687f1611 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -199,300 +199,300 @@ class UserOpExprInferContext : public user_op::InferContext { } user_op::TensorDesc* MutOutputTensorDesc(const std::string& name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(name, index); + } - user_op::TensorDesc* OutputTensorDesc(const std::string& name, int32_t index) override { - return MutTensorDesc4ArgNameAndIndex(name, 
index); - } - - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) - const { - { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - if (tuple_index >= 0) { return *tensor_meta4output_index_(tuple_index); } - } - { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - if (tuple_index >= 0) { - return *const_cast(tensor_meta4input_index_(tuple_index)); - } - } - PRINT_BUG_PROMPT_AND_ABORT(); - return *(user_op::TensorDesc*)nullptr; - } + user_op::TensorDesc* OutputTensorDesc(const std::string& name, int32_t index) override { + return MutTensorDesc4ArgNameAndIndex(name, index); + } - user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { - { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - if (tuple_index >= 0) { return tensor_meta4output_index_(tuple_index); } - } - { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - if (tuple_index >= 0) { - return const_cast(tensor_meta4input_index_(tuple_index)); - } - } - return nullptr; + const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& name, + int32_t index) const { + { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + if (tuple_index >= 0) { return *tensor_meta4output_index_(tuple_index); } } - - const Shape& InputShape(const std::string& name, int32_t index) const override { + { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return 
tensor_meta4input_index_(tuple_index)->shape(); + if (tuple_index >= 0) { + return *const_cast(tensor_meta4input_index_(tuple_index)); + } } + PRINT_BUG_PROMPT_AND_ABORT(); + return *(user_op::TensorDesc*)nullptr; + } - const Shape& OutputShape(const std::string& name, int32_t index) const override { + user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { + { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return tensor_meta4input_index_(tuple_index)->shape(); + if (tuple_index >= 0) { return tensor_meta4output_index_(tuple_index); } } - - Shape* MutOutputShape(const std::string& name, int32_t index) override { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return tensor_meta4output_index_(tuple_index)->mut_shape(); + if (tuple_index >= 0) { + return const_cast(tensor_meta4input_index_(tuple_index)); + } } + return nullptr; + } - const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).shape(); - } + const Shape& InputShape(const std::string& name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4input_index_(tuple_index)->shape(); + } - Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); - } + const Shape& OutputShape(const std::string& name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = 
arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4input_index_(tuple_index)->shape(); + } - const Stride& InputStride(const std::string& name, int32_t index) const override { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return tensor_meta4input_index_(tuple_index)->stride(); - } + Shape* MutOutputShape(const std::string& name, int32_t index) override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4output_index_(tuple_index)->mut_shape(); + } - const Stride& OutputStride(const std::string& name, int32_t index) const override { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return tensor_meta4input_index_(tuple_index)->stride(); - } + const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return TensorDesc4ArgNameAndIndex(arg_name, index).shape(); + } - Stride* MutOutputStride(const std::string& name, int32_t index) override { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return tensor_meta4output_index_(tuple_index)->mut_stride(); - } + Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); + } - const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) - const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).stride(); - } + const Stride& InputStride(const std::string& name, int32_t index) const override { + const 
auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4input_index_(tuple_index)->stride(); + } - Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); - } + const Stride& OutputStride(const std::string& name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4input_index_(tuple_index)->stride(); + } - DataType InputDType(const std::string& arg_name, int32_t index) const override { - return Dtype4ArgNameAndIndex(arg_name, index); - } - DataType OutputDType(const std::string& arg_name, int32_t index) const override { - return Dtype4ArgNameAndIndex(arg_name, index); - } - DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { - return MutDtype4ArgNameAndIndex(arg_name, index); - } - DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).data_type(); - } - DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); - } - bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { - return IsDynamic4ArgNameAndIndex(arg_name, index); - } - bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { - return IsDynamic4ArgNameAndIndex(arg_name, index); - } - bool* MutOutputIsDynamic(const std::string& arg_name, int32_t index) override { - return MutIsDynamic4ArgNameAndIndex(arg_name, index); - } - bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return 
TensorDesc4ArgNameAndIndex(arg_name, index).is_dynamic(); - } - bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { - return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); - } - const std::string& input(const std::string& arg_name, int32_t index) const override { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); - CHECK_GE(tuple_index, 0); - return arg_tuple.indexed_bns().at(tuple_index); - } - const std::string& output(const std::string& arg_name, int32_t index) const override { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); - CHECK_GE(tuple_index, 0); - return arg_tuple.indexed_bns().at(tuple_index); - } - bool has_input(const std::string& arg_name, int32_t index) const override { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); - return tuple_index >= 0; - } - bool has_output(const std::string& arg_name, int32_t index) const override { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); - return tuple_index >= 0; - } - int32_t input_size(const std::string& arg_name) const override { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size(); - } - int32_t output_size(const std::string& arg_name) const override { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); - return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size(); - } - const std::string& op_name() const override { return user_op_expr_->op_name(); } - const std::string& op_type_name() const override { return user_op_expr_->op_type_name(); } - const std::string& 
op_loc() const override { return loc_; } - - private: - const std::shared_ptr& Attr4Name(const std::string& attr_name) - const override { - return composed_attrs_.Attr4Name(attr_name); - } - const UserOpExpr* user_op_expr_; - const ComposedAttrMap composed_attrs_; - const std::function& tensor_meta4input_index_; - const std::function& tensor_meta4output_index_; - std::string loc_; - }; - - class UserOpExprPhysicalInferContext final : public UserOpExprInferContext { - public: - using UserOpExprInferContext::UserOpExprInferContext; - ~UserOpExprPhysicalInferContext() override = default; - - const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, - int32_t index) const override { - UNIMPLEMENTED(); - return nullptr; - } + Stride* MutOutputStride(const std::string& name, int32_t index) override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return tensor_meta4output_index_(tuple_index)->mut_stride(); + } - const ParallelContext& parallel_ctx() const override { - UNIMPLEMENTED(); - return *(const ParallelContext*)nullptr; - } - const ParallelDesc& parallel_desc() const override { - UNIMPLEMENTED(); - return *(const ParallelDesc*)nullptr; - } - const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const override { - UNIMPLEMENTED(); - return *(const SbpParallel*)nullptr; - } - const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const override { - UNIMPLEMENTED(); - return *(const NdSbp*)nullptr; - } - int64_t parallel_num() const override { return 1; } - }; - - class UserOpExprLogicalInferContext final : public UserOpExprInferContext { - public: - UserOpExprLogicalInferContext( - const UserOpExpr* user_op_expr, const AttrMap& attrs, Symbol parallel_desc, - const std::function& TensorMeta4InputIndex, - const std::function& TensorMeta4OutputIndex) - : 
UserOpExprInferContext(user_op_expr, attrs, parallel_desc->device_tag(), - TensorMeta4InputIndex, TensorMeta4OutputIndex), - parallel_desc_(parallel_desc) { - const auto& opt_parallel_id = CHECK_JUST(GetParallelId4CurrentProcessCtx(parallel_desc_)); - // Default parallel_id = -1, which will not cause bad effects becauce it will never be used in - // LogicalTensorDescInfer. - int64_t parallel_id = -1; - if (opt_parallel_id->has_value()) { parallel_id = CHECK_JUST(*opt_parallel_id); } - parallel_ctx_.set_parallel_id(parallel_id); - parallel_ctx_.set_parallel_num(parallel_desc_->parallel_num()); - } - ~UserOpExprLogicalInferContext() override = default; + const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return TensorDesc4ArgNameAndIndex(arg_name, index).stride(); + } - const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, - int32_t index) const override { - UNIMPLEMENTED(); - } + Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); + } - const ParallelContext& parallel_ctx() const override { return parallel_ctx_; } - const ParallelDesc& parallel_desc() const override { return *parallel_desc_; } - const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name, - int32_t index) const override { - const GlobalTensorMeta& tensor_meta = - dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); - Symbol nd_sbp = tensor_meta.nd_sbp(); - CHECK_EQ(nd_sbp->sbp_parallel_size(), 1); - return nd_sbp->sbp_parallel(0); - } - const NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override { - const GlobalTensorMeta& tensor_meta = - dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); - return *tensor_meta.nd_sbp(); - } - int64_t parallel_num() const override { return parallel_desc_->parallel_num(); } - - private: - Symbol parallel_desc_; - ParallelContext 
parallel_ctx_; - }; - - class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { - public: - UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const AttrMap& attrs, - const TensorTuple& input_tensors, - TensorTuple* output_tensors) - : user_op_expr_(user_op_expr), - composed_attrs_(attrs, user_op_expr->base_attrs()), - input_tensors_(&input_tensors), - output_tensors_(output_tensors) {} - - const std::vector>& inputs() const override { - return user_op_expr_->indexed_input_pairs(); - } + DataType InputDType(const std::string& arg_name, int32_t index) const override { + return Dtype4ArgNameAndIndex(arg_name, index); + } + DataType OutputDType(const std::string& arg_name, int32_t index) const override { + return Dtype4ArgNameAndIndex(arg_name, index); + } + DataType* MutOutputDType(const std::string& arg_name, int32_t index) override { + return MutDtype4ArgNameAndIndex(arg_name, index); + } + DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return TensorDesc4ArgNameAndIndex(arg_name, index).data_type(); + } + DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); + } + bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { + return IsDynamic4ArgNameAndIndex(arg_name, index); + } + bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override { + return IsDynamic4ArgNameAndIndex(arg_name, index); + } + bool* MutOutputIsDynamic(const std::string& arg_name, int32_t index) override { + return MutIsDynamic4ArgNameAndIndex(arg_name, index); + } + bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { + return TensorDesc4ArgNameAndIndex(arg_name, index).is_dynamic(); + } + bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { + return 
MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); + } + const std::string& input(const std::string& arg_name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); + CHECK_GE(tuple_index, 0); + return arg_tuple.indexed_bns().at(tuple_index); + } + const std::string& output(const std::string& arg_name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); + CHECK_GE(tuple_index, 0); + return arg_tuple.indexed_bns().at(tuple_index); + } + bool has_input(const std::string& arg_name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); + return tuple_index >= 0; + } + bool has_output(const std::string& arg_name, int32_t index) const override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index); + return tuple_index >= 0; + } + int32_t input_size(const std::string& arg_name) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size(); + } + int32_t output_size(const std::string& arg_name) const override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size(); + } + const std::string& op_name() const override { return user_op_expr_->op_name(); } + const std::string& op_type_name() const override { return user_op_expr_->op_type_name(); } + const std::string& op_loc() const override { return loc_; } + + private: + const std::shared_ptr& Attr4Name( + const std::string& attr_name) const override { + return 
composed_attrs_.Attr4Name(attr_name); + } + const UserOpExpr* user_op_expr_; + const ComposedAttrMap composed_attrs_; + const std::function& tensor_meta4input_index_; + const std::function& tensor_meta4output_index_; + std::string loc_; +}; + +class UserOpExprPhysicalInferContext final : public UserOpExprInferContext { + public: + using UserOpExprInferContext::UserOpExprInferContext; + ~UserOpExprPhysicalInferContext() override = default; - const std::vector>& outputs() const override { - return user_op_expr_->indexed_output_pairs(); - } + const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, + int32_t index) const override { + UNIMPLEMENTED(); + return nullptr; + } - Symbol* OutputTensorDevice4ArgNameAndIndex(const std::string& name, - int64_t index) override { - const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return CHECK_JUST(output_tensors_->at(tuple_index)->mut_device()); - } + const ParallelContext& parallel_ctx() const override { + UNIMPLEMENTED(); + return *(const ParallelContext*)nullptr; + } + const ParallelDesc& parallel_desc() const override { + UNIMPLEMENTED(); + return *(const ParallelDesc*)nullptr; + } + const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const override { + UNIMPLEMENTED(); + return *(const SbpParallel*)nullptr; + } + const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const override { + UNIMPLEMENTED(); + return *(const NdSbp*)nullptr; + } + int64_t parallel_num() const override { return 1; } +}; - Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, - int64_t index) const override { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); - int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - CHECK_GE(tuple_index, 0); - return CHECK_JUST(input_tensors_->at(tuple_index)->device()); - } +class 
UserOpExprLogicalInferContext final : public UserOpExprInferContext { + public: + UserOpExprLogicalInferContext( + const UserOpExpr* user_op_expr, const AttrMap& attrs, Symbol parallel_desc, + const std::function& TensorMeta4InputIndex, + const std::function& TensorMeta4OutputIndex) + : UserOpExprInferContext(user_op_expr, attrs, parallel_desc->device_tag(), + TensorMeta4InputIndex, TensorMeta4OutputIndex), + parallel_desc_(parallel_desc) { + const auto& opt_parallel_id = CHECK_JUST(GetParallelId4CurrentProcessCtx(parallel_desc_)); + // Default parallel_id = -1, which will not cause bad effects becauce it will never be used in + // LogicalTensorDescInfer. + int64_t parallel_id = -1; + if (opt_parallel_id->has_value()) { parallel_id = CHECK_JUST(*opt_parallel_id); } + parallel_ctx_.set_parallel_id(parallel_id); + parallel_ctx_.set_parallel_num(parallel_desc_->parallel_num()); + } + ~UserOpExprLogicalInferContext() override = default; - private: - const std::shared_ptr& Attr4Name( - const std::string& attr_name) const override { - return composed_attrs_.Attr4Name(attr_name); - } - const UserOpExpr* user_op_expr_; - const ComposedAttrMap composed_attrs_; - const TensorTuple* input_tensors_; - TensorTuple* output_tensors_; - }; + const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, + int32_t index) const override { + UNIMPLEMENTED(); + } + + const ParallelContext& parallel_ctx() const override { return parallel_ctx_; } + const ParallelDesc& parallel_desc() const override { return *parallel_desc_; } + const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name, + int32_t index) const override { + const GlobalTensorMeta& tensor_meta = + dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); + Symbol nd_sbp = tensor_meta.nd_sbp(); + CHECK_EQ(nd_sbp->sbp_parallel_size(), 1); + return nd_sbp->sbp_parallel(0); + } + const NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override { + const GlobalTensorMeta& 
tensor_meta = + dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); + return *tensor_meta.nd_sbp(); + } + int64_t parallel_num() const override { return parallel_desc_->parallel_num(); } + + private: + Symbol parallel_desc_; + ParallelContext parallel_ctx_; +}; + +class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext { + public: + UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const AttrMap& attrs, + const TensorTuple& input_tensors, + TensorTuple* output_tensors) + : user_op_expr_(user_op_expr), + composed_attrs_(attrs, user_op_expr->base_attrs()), + input_tensors_(&input_tensors), + output_tensors_(output_tensors) {} + + const std::vector>& inputs() const override { + return user_op_expr_->indexed_input_pairs(); + } + + const std::vector>& outputs() const override { + return user_op_expr_->indexed_output_pairs(); + } + + Symbol* OutputTensorDevice4ArgNameAndIndex(const std::string& name, + int64_t index) override { + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return CHECK_JUST(output_tensors_->at(tuple_index)->mut_device()); + } + + Symbol InputTensorDevice4ArgNameAndIndex(const std::string& name, + int64_t index) const override { + const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); + CHECK_GE(tuple_index, 0); + return CHECK_JUST(input_tensors_->at(tuple_index)->device()); + } + + private: + const std::shared_ptr& Attr4Name( + const std::string& attr_name) const override { + return composed_attrs_.Attr4Name(attr_name); + } + const UserOpExpr* user_op_expr_; + const ComposedAttrMap composed_attrs_; + const TensorTuple* input_tensors_; + TensorTuple* output_tensors_; +}; } // namespace From 543248502a8068e94f531444380744d88b8e853f Mon Sep 17 00:00:00 2001 From: clackhan Date: Fri, 22 
Jul 2022 10:11:10 +0800 Subject: [PATCH 52/67] fix warning error --- oneflow/core/framework/op_expr.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 0c8687f1611..208da37219b 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -201,10 +201,6 @@ class UserOpExprInferContext : public user_op::InferContext { return MutTensorDesc4ArgNameAndIndex(name, index); } - user_op::TensorDesc* OutputTensorDesc(const std::string& name, int32_t index) override { - return MutTensorDesc4ArgNameAndIndex(name, index); - } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const { { @@ -380,7 +376,7 @@ class UserOpExprPhysicalInferContext final : public UserOpExprInferContext { const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const override { - UNIMPLEMENTED(); + PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } @@ -424,7 +420,8 @@ class UserOpExprLogicalInferContext final : public UserOpExprInferContext { const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const override { - UNIMPLEMENTED(); + PRINT_BUG_PROMPT_AND_ABORT(); + return nullptr; } const ParallelContext& parallel_ctx() const override { return parallel_ctx_; } From 1b114434b7af108cbc2f21701209267446e6ebd7 Mon Sep 17 00:00:00 2001 From: clackhan Date: Fri, 22 Jul 2022 15:23:39 +0800 Subject: [PATCH 53/67] refine --- oneflow/user/kernels/stateful_opkernel.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index e12a35b0364..675e491b4f3 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -60,6 +60,7 @@ class ZeroCopyBaseContextHelper { inedx = 
TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, index); if (inedx >= 0) { return *call_ctx->outputs().at(inedx); } + PRINT_BUG_PROMPT_AND_ABORT(); return *(user_op::TensorDesc*)nullptr; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, From e6b6e1eaf87a153ba82a0263e27f9a00161b1e98 Mon Sep 17 00:00:00 2001 From: clackhan Date: Fri, 22 Jul 2022 17:07:45 +0800 Subject: [PATCH 54/67] fix static check error --- oneflow/core/framework/infer_util.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/oneflow/core/framework/infer_util.cpp b/oneflow/core/framework/infer_util.cpp index 27648f0d489..f63f70480f4 100644 --- a/oneflow/core/framework/infer_util.cpp +++ b/oneflow/core/framework/infer_util.cpp @@ -39,8 +39,10 @@ Maybe TensorDescInferFnUtil::Unchanged(InferContext* ctx) { } for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); - *ctx->MutOutputIsDynamic(output_arg.first, output_arg.second) = first_tensor_desc->is_dynamic(); - *ctx->MutOutputShape(output_arg.first, output_arg.second) = first_tensor_desc->shape(); + *ctx->MutOutputIsDynamic(output_arg.first, output_arg.second) = // NOLINT + first_tensor_desc->is_dynamic(); // NOLINT + *ctx->MutOutputShape(output_arg.first, output_arg.second) = // NOLINT + first_tensor_desc->shape(); // NOLINT } return Maybe::Ok(); } @@ -58,7 +60,8 @@ Maybe TensorDescInferFnUtil::UnchangedDataType(InferContext* ctx) { } for (size_t i = 0; i < ctx->outputs().size(); ++i) { const std::pair& output_arg = ctx->outputs().at(i); - *ctx->MutOutputDType(output_arg.first, output_arg.second) = first_tensor_desc->data_type(); + *ctx->MutOutputDType(output_arg.first, output_arg.second) = // NOLINT + first_tensor_desc->data_type(); // NOLINT } return Maybe::Ok(); } From ea3a604ca06cba0b18ec094d602ab1b33dd502a3 Mon Sep 17 00:00:00 2001 From: binbinHan Date: Fri, 22 Jul 2022 19:45:49 +0800 Subject: 
[PATCH 55/67] Update op_expr.cpp --- oneflow/core/framework/op_expr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index cf07de52b32..08dada8c639 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -269,7 +269,7 @@ class UserOpExprInferContext : public user_op::InferContext { } const Stride& OutputStride(const std::string& name, int32_t index) const override { - const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); + const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); return tensor_meta4input_index_(tuple_index)->stride(); From c8bcacbc9e53de5d7d713358342663c4cd15fdfe Mon Sep 17 00:00:00 2001 From: binbinHan Date: Fri, 22 Jul 2022 19:47:56 +0800 Subject: [PATCH 56/67] Update op_expr.cpp --- oneflow/core/framework/op_expr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 08dada8c639..df755882ab0 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -272,7 +272,7 @@ class UserOpExprInferContext : public user_op::InferContext { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); - return tensor_meta4input_index_(tuple_index)->stride(); + return tensor_meta4output_index_(tuple_index)->stride(); } Stride* MutOutputStride(const std::string& name, int32_t index) override { From 8e0d38846c0ef2990d5e5ac25082e13b9dba3ac4 Mon Sep 17 00:00:00 2001 From: clackhan Date: Fri, 22 Jul 2022 20:04:02 +0800 Subject: [PATCH 57/67] split MutTensorMeta and MutLocalTensorMeta --- oneflow/core/common/tensor_meta.cpp | 31 ++++++++++++-- oneflow/core/common/tensor_meta.h | 40 +++++++++++++------ 
.../framework/global_tensor_infer_cache.h | 2 +- oneflow/core/framework/op_expr.cpp | 19 +++++++-- 4 files changed, 71 insertions(+), 21 deletions(-) diff --git a/oneflow/core/common/tensor_meta.cpp b/oneflow/core/common/tensor_meta.cpp index 06beb5d4262..e488bf94695 100644 --- a/oneflow/core/common/tensor_meta.cpp +++ b/oneflow/core/common/tensor_meta.cpp @@ -20,6 +20,29 @@ limitations under the License. namespace oneflow { namespace one { +MutTensorMeta::MutTensorMeta() + : TensorMeta(std::make_shared(), std::make_shared(), + kInvalidDataType) {} + +MutTensorMeta::MutTensorMeta(const std::shared_ptr& shape, DataType dtype) + : TensorMeta(shape, std::make_shared(*shape), dtype) {} + +MutTensorMeta::MutTensorMeta(const std::shared_ptr& shape, + const std::shared_ptr& stride, DataType dtype) + : TensorMeta(shape, stride, dtype) {} + +bool MutTensorMeta::operator==(const MutTensorMeta& other) const { + // It's correct to ignore is_dynamic_ field. + return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype() + && this->stride() == other.stride(); +} + +size_t MutTensorMeta::CalcHashValue() const { + // It's correct to ignore is_dynamic_ field. 
+ return std::hash()(*shape_ptr()) ^ std::hash()(dtype()) + ^ std::hash()(stride()); +} + LocalTensorMeta::LocalTensorMeta() : TensorMeta(std::make_shared(), std::make_shared(), DataType::kInvalidDataType), @@ -51,21 +74,21 @@ size_t LocalTensorMeta::CalcHashValue() const { } MutLocalTensorMeta::MutLocalTensorMeta() - : TensorMeta(std::make_shared(), std::make_shared(), - kInvalidDataType), + : MutTensorMeta(std::make_shared(), std::make_shared(), + kInvalidDataType), device_(Symbol()), storage_offset_(0) {} MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr& shape, DataType dtype, Symbol device) - : TensorMeta(shape, std::make_shared(*shape), dtype), + : MutTensorMeta(shape, std::make_shared(*shape), dtype), device_(device), storage_offset_(0) {} MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr& shape, const std::shared_ptr& stride, DataType dtype, Symbol device, int64_t storage_offset) - : TensorMeta(shape, stride, dtype), device_(device), storage_offset_(storage_offset) {} + : MutTensorMeta(shape, stride, dtype), device_(device), storage_offset_(storage_offset) {} bool MutLocalTensorMeta::operator==(const MutLocalTensorMeta& other) const { // It's correct to ignore is_dynamic_ field. diff --git a/oneflow/core/common/tensor_meta.h b/oneflow/core/common/tensor_meta.h index 26d086cf68b..b846bcc74bb 100644 --- a/oneflow/core/common/tensor_meta.h +++ b/oneflow/core/common/tensor_meta.h @@ -100,6 +100,33 @@ class TensorMeta : public user_op::TensorDesc { bool is_dynamic_; }; +class MutTensorMeta : public TensorMeta { + public: + // uninitialized MutTensorMeta. 
+ MutTensorMeta(); + MutTensorMeta(const MutTensorMeta&) = default; + MutTensorMeta(const std::shared_ptr& shape, DataType dtype); + MutTensorMeta(const std::shared_ptr& shape, + const std::shared_ptr& stride, DataType dtype); + virtual ~MutTensorMeta() = default; + + Shape* mut_shape() override { return const_cast(shape_.get()); } + Stride* mut_stride() override { return const_cast(stride_.get()); } + DataType* mut_data_type() override { return &data_type_; } + bool* mut_is_dynamic() override { return &is_dynamic_; } + void set_is_dynamic(bool val) override { is_dynamic_ = val; } + + void set_shape(const std::shared_ptr& val) override { shape_ = val; } + void set_stride(const std::shared_ptr& val) override { stride_ = val; } + DataType* mut_dtype() override { return &data_type_; } + void set_dtype(DataType data_type) override { data_type_ = data_type; } + + bool operator==(const MutTensorMeta& other) const; + size_t CalcHashValue() const; + + MutTensorMeta& operator=(const MutTensorMeta& other) = default; +}; + class LocalTensorMeta : public TensorMeta { public: // uninitialized LocalTensorMeta. @@ -124,7 +151,7 @@ class LocalTensorMeta : public TensorMeta { int64_t storage_offset_; }; -class MutLocalTensorMeta : public TensorMeta { +class MutLocalTensorMeta : public MutTensorMeta { public: // uninitialized MutLocalTensorMeta. 
MutLocalTensorMeta(); @@ -142,17 +169,6 @@ class MutLocalTensorMeta : public TensorMeta { Symbol* mut_device() { return &device_; } void set_storage_offset(int64_t offset) { storage_offset_ = offset; } - Shape* mut_shape() override { return const_cast(shape_.get()); } - Stride* mut_stride() override { return const_cast(stride_.get()); } - DataType* mut_data_type() override { return &data_type_; } - bool* mut_is_dynamic() override { return &is_dynamic_; } - void set_is_dynamic(bool val) override { is_dynamic_ = val; } - - void set_shape(const std::shared_ptr& val) override { shape_ = val; } - void set_stride(const std::shared_ptr& val) override { stride_ = val; } - DataType* mut_dtype() override { return &data_type_; } - void set_dtype(DataType data_type) override { data_type_ = data_type; } - bool operator==(const MutLocalTensorMeta& other) const; size_t CalcHashValue() const; diff --git a/oneflow/core/framework/global_tensor_infer_cache.h b/oneflow/core/framework/global_tensor_infer_cache.h index f0cee95cd87..773ac205486 100644 --- a/oneflow/core/framework/global_tensor_infer_cache.h +++ b/oneflow/core/framework/global_tensor_infer_cache.h @@ -140,7 +140,7 @@ class OpArgMutGlobalTensorMeta final { TensorMeta* mut_tensor_meta() { return &tensor_meta_; } private: - TensorMeta tensor_meta_; + MutTensorMeta tensor_meta_; }; } // namespace one diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 208da37219b..65db45f869d 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -223,15 +223,22 @@ class UserOpExprInferContext : public user_op::InferContext { { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - if (tuple_index >= 0) { return tensor_meta4output_index_(tuple_index); } + if (tuple_index >= 0) { + TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index); + 
CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); + return tensor_meta_ptr; + } } { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); if (tuple_index >= 0) { - return const_cast(tensor_meta4input_index_(tuple_index)); + const TensorMeta* tensor_meta_ptr = tensor_meta4input_index_(tuple_index); + CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); + return const_cast(tensor_meta_ptr); } } + PRINT_BUG_PROMPT_AND_ABORT(); return nullptr; } @@ -253,7 +260,9 @@ class UserOpExprInferContext : public user_op::InferContext { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); - return tensor_meta4output_index_(tuple_index)->mut_shape(); + TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index); + CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); + return tensor_meta_ptr->mut_shape(); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { @@ -282,7 +291,9 @@ class UserOpExprInferContext : public user_op::InferContext { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); CHECK_GE(tuple_index, 0); - return tensor_meta4output_index_(tuple_index)->mut_stride(); + TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index); + CHECK_NOTNULL(dynamic_cast(tensor_meta_ptr)); + return tensor_meta_ptr->mut_stride(); } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { From 855ee7c1cfb1eef72be911d51f099e2d2339fc86 Mon Sep 17 00:00:00 2001 From: binbinHan Date: Fri, 22 Jul 2022 23:06:36 +0800 Subject: [PATCH 58/67] Update stateful_opkernel.cpp --- oneflow/user/kernels/stateful_opkernel.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/oneflow/user/kernels/stateful_opkernel.cpp 
b/oneflow/user/kernels/stateful_opkernel.cpp index 675e491b4f3..abd3caa104a 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -332,7 +332,6 @@ class UserOpInferContextHelper final { const std::string& arg_name, int32_t index) const { const user_op::TensorDesc& tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); - if (!(&tensor_desc)) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; } return tensor_desc; } user_op::TensorDesc* MutNonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, From 767af85f317ae1932ed2f175c0d29ecc543eb7a2 Mon Sep 17 00:00:00 2001 From: clackhan Date: Sat, 23 Jul 2022 10:04:45 +0800 Subject: [PATCH 59/67] refine --- oneflow/core/framework/op_expr.cpp | 1 - oneflow/core/kernel/user_kernel.cpp | 5 +---- oneflow/core/operator/user_op.cpp | 5 +---- oneflow/user/kernels/stateful_opkernel.cpp | 22 ++++------------------ 4 files changed, 6 insertions(+), 27 deletions(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index df755882ab0..c83fcd4c5e2 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -212,7 +212,6 @@ class UserOpExprInferContext : public user_op::InferContext { return *const_cast(tensor_meta4input_index_(tuple_index)); } } - PRINT_BUG_PROMPT_AND_ABORT(); return *(user_op::TensorDesc*)nullptr; } diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 9668ef2b0f1..7a29d0fa53d 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -257,10 +257,7 @@ class UserKernelOpInferContext : public user_op::InferContext { const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { - PRINT_BUG_PROMPT_AND_ABORT(); - return 
*(user_op::TensorDesc*)nullptr; - } + if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; } return *it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 4ed1a75a096..913920b0180 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -155,10 +155,7 @@ class UserOpInferContext final : public user_op::InferContext { const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { - PRINT_BUG_PROMPT_AND_ABORT(); - return *(user_op::TensorDesc*)nullptr; - }; + if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; }; return it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index abd3caa104a..f40160a345b 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -60,7 +60,6 @@ class ZeroCopyBaseContextHelper { inedx = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, index); if (inedx >= 0) { return *call_ctx->outputs().at(inedx); } - PRINT_BUG_PROMPT_AND_ABORT(); return *(user_op::TensorDesc*)nullptr; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, @@ -332,6 +331,9 @@ class UserOpInferContextHelper final { const std::string& arg_name, int32_t index) const { const user_op::TensorDesc& tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + if ((&tensor_desc) == nullptr) { + LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; + } return tensor_desc; } user_op::TensorDesc* 
MutNonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, @@ -498,12 +500,6 @@ class UserKernelComputeContextHelper final { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } - const user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, - const std::string& arg_name, - int32_t index) const { - return base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); - } - user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index); @@ -590,11 +586,6 @@ class UserKernelRegContextHelper final { int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } - const user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, - const std::string& arg_name, - int32_t index) const { - return base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); - } const ArgVec& inputs() const { return base_ctx_helper_.inputs(); } const ArgVec& outputs() const { return base_ctx_helper_.outputs(); } @@ -659,11 +650,6 @@ class UserKernelInitAndCacheContextHelper final { int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } - const user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, - const std::string& arg_name, - int32_t index) const { - return base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); - } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { @@ -719,7 +705,7 @@ class UserKernelInitAndCacheContext final : public user_op::KernelInitContext, const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const 
std::string& arg_name, int32_t index) const override { - return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return &helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { From 863f010b872cd535b84765b131549e91376ef969 Mon Sep 17 00:00:00 2001 From: clackhan Date: Sun, 24 Jul 2022 08:52:37 +0800 Subject: [PATCH 60/67] fix static check error --- oneflow/core/framework/op_expr.cpp | 2 +- oneflow/core/kernel/user_kernel.cpp | 2 +- oneflow/core/operator/user_op.cpp | 2 +- oneflow/user/kernels/stateful_opkernel.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index c83fcd4c5e2..9bab0e48868 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -212,7 +212,7 @@ class UserOpExprInferContext : public user_op::InferContext { return *const_cast(tensor_meta4input_index_(tuple_index)); } } - return *(user_op::TensorDesc*)nullptr; + return *(user_op::TensorDesc*)nullptr; // NOLINT } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 7a29d0fa53d..c6f5199f557 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -257,7 +257,7 @@ class UserKernelOpInferContext : public user_op::InferContext { const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; } + if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; } // NOLINT return *it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& 
arg_name, int32_t index) { diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 913920b0180..337d4c8ae51 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -155,7 +155,7 @@ class UserOpInferContext final : public user_op::InferContext { const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; }; + if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; } // NOLINT return it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index f40160a345b..6a47c8a754a 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -60,7 +60,7 @@ class ZeroCopyBaseContextHelper { inedx = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, index); if (inedx >= 0) { return *call_ctx->outputs().at(inedx); } - return *(user_op::TensorDesc*)nullptr; + return *(user_op::TensorDesc*)nullptr; // NOLINT } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, From d3a8a22e8e7ac22f97ef44fbf610f05f42ace325 Mon Sep 17 00:00:00 2001 From: clackhan Date: Sun, 24 Jul 2022 09:24:55 +0800 Subject: [PATCH 61/67] refine --- oneflow/core/framework/op_expr.cpp | 1 + oneflow/core/kernel/user_kernel.cpp | 5 ++++- oneflow/core/operator/user_op.cpp | 5 ++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 9bab0e48868..9123dad322c 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -212,6 +212,7 @@ class 
UserOpExprInferContext : public user_op::InferContext { return *const_cast(tensor_meta4input_index_(tuple_index)); } } + PRINT_BUG_PROMPT_AND_ABORT(); return *(user_op::TensorDesc*)nullptr; // NOLINT } diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index c6f5199f557..9b004775e39 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -257,7 +257,10 @@ class UserKernelOpInferContext : public user_op::InferContext { const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; } // NOLINT + if (it == arg2tensor_desc_.end()) { + PRINT_BUG_PROMPT_AND_ABORT(); + return *(user_op::TensorDesc*)nullptr; // NOLINT + } return *it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 337d4c8ae51..215cb9b5e84 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -155,7 +155,10 @@ class UserOpInferContext final : public user_op::InferContext { const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { return *(user_op::TensorDesc*)nullptr; } // NOLINT + if (it == arg2tensor_desc_.end()) { + PRINT_BUG_PROMPT_AND_ABORT(); + return *(user_op::TensorDesc*)nullptr; // NOLINT + } return it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { From c8c333ffe448f5c50b5194747e9cf17adea820cf Mon Sep 17 00:00:00 2001 From: clackhan Date: Sun, 24 Jul 2022 11:55:02 +0800 Subject: [PATCH 62/67] refine --- oneflow/core/framework/op_expr.cpp | 33 
+++++++++---------- oneflow/core/kernel/user_kernel.cpp | 19 +++++------ oneflow/core/operator/user_op.cpp | 8 ++--- oneflow/user/kernels/stateful_opkernel.cpp | 37 +++++++++------------- 4 files changed, 42 insertions(+), 55 deletions(-) diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 9123dad322c..2a1a3ef355b 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -191,29 +191,26 @@ class UserOpExprInferContext : public user_op::InferContext { const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index); + return *TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* OutputTensorDesc(const std::string& name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(name, index); } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& name, + const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const { { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - if (tuple_index >= 0) { return *tensor_meta4output_index_(tuple_index); } + if (tuple_index >= 0) { return tensor_meta4output_index_(tuple_index); } } { const auto& arg_tuple = *user_op_expr_->input_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); - if (tuple_index >= 0) { - return *const_cast(tensor_meta4input_index_(tuple_index)); - } + if (tuple_index >= 0) { return tensor_meta4input_index_(tuple_index); } } - PRINT_BUG_PROMPT_AND_ABORT(); - return *(user_op::TensorDesc*)nullptr; // NOLINT + return nullptr; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { @@ -254,7 +251,7 @@ class UserOpExprInferContext : public user_op::InferContext { } const Shape& Shape4ArgNameAndIndex(const 
std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).shape(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->shape(); } Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { @@ -283,7 +280,7 @@ class UserOpExprInferContext : public user_op::InferContext { } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).stride(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->stride(); } Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { @@ -300,7 +297,7 @@ class UserOpExprInferContext : public user_op::InferContext { return MutDtype4ArgNameAndIndex(arg_name, index); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).data_type(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->data_type(); } DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); @@ -315,7 +312,7 @@ class UserOpExprInferContext : public user_op::InferContext { return MutIsDynamic4ArgNameAndIndex(arg_name, index); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).is_dynamic(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->is_dynamic(); } bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); @@ -424,16 +421,16 @@ class UserOpExprLogicalInferContext final : public UserOpExprInferContext { const ParallelDesc& parallel_desc() const override { return *parallel_desc_; } const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name, int32_t index) const 
override { - const GlobalTensorMeta& tensor_meta = - dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); - Symbol nd_sbp = tensor_meta.nd_sbp(); + const GlobalTensorMeta* tensor_meta = + dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); + Symbol nd_sbp = tensor_meta->nd_sbp(); CHECK_EQ(nd_sbp->sbp_parallel_size(), 1); return nd_sbp->sbp_parallel(0); } const NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override { - const GlobalTensorMeta& tensor_meta = - dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); - return *tensor_meta.nd_sbp(); + const GlobalTensorMeta* tensor_meta = + dynamic_cast(TensorDesc4ArgNameAndIndex(name, index)); + return *tensor_meta->nd_sbp(); } int64_t parallel_num() const override { return parallel_desc_->parallel_num(); } diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 9b004775e39..af13b75d2dc 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -249,19 +249,16 @@ class UserKernelOpInferContext : public user_op::InferContext { const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index); + return *TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index); } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, + const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { - PRINT_BUG_PROMPT_AND_ABORT(); - return *(user_op::TensorDesc*)nullptr; // NOLINT - } - return *it->second; + if (it == arg2tensor_desc_.end()) { return nullptr; } + return it->second.get(); } user_op::TensorDesc* 
MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); @@ -278,7 +275,7 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutShape4ArgNameAndIndex(arg_name, index); } const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).shape(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->shape(); } Shape* MutShape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape(); @@ -293,7 +290,7 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutStride4ArgNameAndIndex(arg_name, index); } const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).stride(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->stride(); } Stride* MutStride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_stride(); @@ -308,7 +305,7 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutDtype4ArgNameAndIndex(arg_name, index); } DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index).data_type(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->data_type(); } DataType* MutDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type(); @@ -323,7 +320,7 @@ class UserKernelOpInferContext : public user_op::InferContext { return MutIsDynamic4ArgNameAndIndex(arg_name, index); } bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, 
index).is_dynamic(); + return TensorDesc4ArgNameAndIndex(arg_name, index)->is_dynamic(); } bool* MutIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic(); diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 215cb9b5e84..b87edc8c363 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -147,19 +147,19 @@ class UserOpInferContext final : public user_op::InferContext { const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, int32_t index) const override { - return TensorDesc4ArgNameAndIndex(arg_name, index); + return *TensorDesc4ArgNameAndIndex(arg_name, index); } user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { return MutTensorDesc4ArgNameAndIndex(arg_name, index); } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, + const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { PRINT_BUG_PROMPT_AND_ABORT(); - return *(user_op::TensorDesc*)nullptr; // NOLINT + return nullptr; } - return it->second; + return &it->second; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 6a47c8a754a..830edaa0f71 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -51,16 +51,11 @@ class ZeroCopyBaseContextHelper { index); \ if (i >= 0) { return (outputs).at(i) post_action; } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const user_op::TensorDesc* 
TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, const int32_t index) const { - int32_t inedx = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), - arg_name, index); - if (inedx >= 0) { return *call_ctx->inputs().at(inedx); } - inedx = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), - arg_name, index); - if (inedx >= 0) { return *call_ctx->outputs().at(inedx); } - return *(user_op::TensorDesc*)nullptr; // NOLINT + RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get()); + return nullptr; } user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, @@ -170,14 +165,14 @@ class UserOpInferContextHelper final { const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + return *TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); @@ -330,11 +325,9 @@ class UserOpInferContextHelper final { const user_op::TensorDesc& NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { - const user_op::TensorDesc& tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); - if ((&tensor_desc) == nullptr) { - LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; - } - return tensor_desc; + const user_op::TensorDesc* 
tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); + if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; } + return *tensor_desc; } user_op::TensorDesc* MutNonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, @@ -367,7 +360,7 @@ class UserOpInferContext : public user_op::InferContext { user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { return helper_->OutputTensorDesc(call_ctx_, arg_name, index); } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, + const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } @@ -494,7 +487,7 @@ class UserKernelComputeContextHelper final { ~UserKernelComputeContextHelper() = default; - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); @@ -534,7 +527,7 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext { const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return &helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { @@ -581,7 +574,7 @@ class UserKernelRegContextHelper final { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return base_ctx_helper_.parallel_ctx(call_ctx); } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const user_op::TensorDesc* 
TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); @@ -611,7 +604,7 @@ class UserKernelRegContext final : public user_op::KernelRegContext { const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return &helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& outputs() const override { return helper_->outputs(); } @@ -645,7 +638,7 @@ class UserKernelInitAndCacheContextHelper final { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { return base_ctx_helper_.parallel_ctx(call_ctx); } - const user_op::TensorDesc& TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, + const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index) const { return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); @@ -705,7 +698,7 @@ class UserKernelInitAndCacheContext final : public user_op::KernelInitContext, const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { - return &helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); + return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); } const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { From 5d760ff070b738ef70007889cd3c1c989ea05ab0 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 25 Jul 2022 11:46:45 +0800 Subject: 
[PATCH 63/67] reslove comment --- oneflow/core/common/tensor_meta.h | 15 ++----- oneflow/core/eager/eager_blob_object.cpp | 54 ++++++++++++------------ oneflow/core/eager/eager_blob_object.h | 18 ++++---- 3 files changed, 40 insertions(+), 47 deletions(-) diff --git a/oneflow/core/common/tensor_meta.h b/oneflow/core/common/tensor_meta.h index b846bcc74bb..c84fef5a02f 100644 --- a/oneflow/core/common/tensor_meta.h +++ b/oneflow/core/common/tensor_meta.h @@ -78,13 +78,6 @@ class TensorMeta : public user_op::TensorDesc { } virtual void set_is_dynamic(bool val) override { PRINT_BUG_PROMPT_AND_ABORT(); } - virtual void set_shape(const std::shared_ptr& val) { PRINT_BUG_PROMPT_AND_ABORT(); } - virtual void set_stride(const std::shared_ptr& val) { - PRINT_BUG_PROMPT_AND_ABORT(); - } - virtual DataType* mut_dtype() { PRINT_BUG_PROMPT_AND_ABORT(); } - virtual void set_dtype(DataType data_type) { PRINT_BUG_PROMPT_AND_ABORT(); } - protected: TensorMeta& operator=(const TensorMeta& other) { this->shape_ = std::make_shared(*other.shape_); @@ -116,10 +109,10 @@ class MutTensorMeta : public TensorMeta { bool* mut_is_dynamic() override { return &is_dynamic_; } void set_is_dynamic(bool val) override { is_dynamic_ = val; } - void set_shape(const std::shared_ptr& val) override { shape_ = val; } - void set_stride(const std::shared_ptr& val) override { stride_ = val; } - DataType* mut_dtype() override { return &data_type_; } - void set_dtype(DataType data_type) override { data_type_ = data_type; } + void set_shape(const std::shared_ptr& val) { shape_ = val; } + void set_stride(const std::shared_ptr& val) { stride_ = val; } + DataType* mut_dtype() { return &data_type_; } + void set_dtype(DataType data_type) { data_type_ = data_type; } bool operator==(const MutTensorMeta& other) const; size_t CalcHashValue() const; diff --git a/oneflow/core/eager/eager_blob_object.cpp b/oneflow/core/eager/eager_blob_object.cpp index a83d5653d38..b9bf6f9d895 100644 --- 
a/oneflow/core/eager/eager_blob_object.cpp +++ b/oneflow/core/eager/eager_blob_object.cpp @@ -26,9 +26,9 @@ namespace vm { EagerBlobObject::EagerBlobObject( const std::shared_ptr& mem_case, - const Symbol& local_tensor_meta, - const std::shared_ptr& mut_local_tensor_meta, DataType data_type, - const std::shared_ptr& tensor_storage, + const Symbol& static_local_tensor_meta, + const std::shared_ptr& dynamic_local_tensor_meta, + DataType data_type, const std::shared_ptr& tensor_storage, const intrusive::shared_ptr& dep_object) : is_dynamic_(false), mem_case_(mem_case), @@ -40,54 +40,54 @@ EagerBlobObject::EagerBlobObject( is_non_pod_object_placement_newed_(false), pin_memory_(false), compute_local_dep_object_(dep_object), - blob_desc_(static_cast(mut_local_tensor_meta) - ? std::const_pointer_cast(mut_local_tensor_meta->shape_ptr()) - : std::const_pointer_cast(local_tensor_meta->shape_ptr()), - static_cast(mut_local_tensor_meta) - ? std::const_pointer_cast(mut_local_tensor_meta->stride_ptr()) - : std::const_pointer_cast(local_tensor_meta->stride_ptr()), + blob_desc_(static_cast(dynamic_local_tensor_meta) + ? std::const_pointer_cast(dynamic_local_tensor_meta->shape_ptr()) + : std::const_pointer_cast(static_local_tensor_meta->shape_ptr()), + static_cast(dynamic_local_tensor_meta) + ? 
std::const_pointer_cast(dynamic_local_tensor_meta->stride_ptr()) + : std::const_pointer_cast(static_local_tensor_meta->stride_ptr()), data_type), - local_tensor_meta_(local_tensor_meta), - mut_local_tensor_meta_(mut_local_tensor_meta) { + static_local_tensor_meta_(static_local_tensor_meta), + dynamic_local_tensor_meta_(dynamic_local_tensor_meta) { CHECK(static_cast(tensor_storage)); } // user_op::TensorDesc overrides const Shape& EagerBlobObject::shape() const { - if (mut_local_tensor_meta_) { - return mut_local_tensor_meta_->shape(); + if (dynamic_local_tensor_meta_) { + return dynamic_local_tensor_meta_->shape(); } else { - return local_tensor_meta_->shape(); + return static_local_tensor_meta_->shape(); } } Shape* EagerBlobObject::mut_shape() { - CHECK(mut_local_tensor_meta_); - return std::const_pointer_cast(mut_local_tensor_meta_)->mut_shape(); + CHECK(dynamic_local_tensor_meta_); + return std::const_pointer_cast(dynamic_local_tensor_meta_)->mut_shape(); } const Stride& EagerBlobObject::stride() const { - if (mut_local_tensor_meta_) { - return mut_local_tensor_meta_->stride(); + if (dynamic_local_tensor_meta_) { + return dynamic_local_tensor_meta_->stride(); } else { - return local_tensor_meta_->stride(); + return static_local_tensor_meta_->stride(); } } Stride* EagerBlobObject::mut_stride() { - CHECK(mut_local_tensor_meta_); - return std::const_pointer_cast(mut_local_tensor_meta_)->mut_stride(); + CHECK(dynamic_local_tensor_meta_); + return std::const_pointer_cast(dynamic_local_tensor_meta_)->mut_stride(); } std::shared_ptr EagerBlobObject::shape_ptr() const { - if (mut_local_tensor_meta_) { - return mut_local_tensor_meta_->shape_ptr(); + if (dynamic_local_tensor_meta_) { + return dynamic_local_tensor_meta_->shape_ptr(); } else { - return local_tensor_meta_->shape_ptr(); + return static_local_tensor_meta_->shape_ptr(); } } std::shared_ptr EagerBlobObject::stride_ptr() const { - if (mut_local_tensor_meta_) { - return mut_local_tensor_meta_->stride_ptr(); + if 
(dynamic_local_tensor_meta_) { + return dynamic_local_tensor_meta_->stride_ptr(); } else { - return local_tensor_meta_->stride_ptr(); + return static_local_tensor_meta_->stride_ptr(); } } diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index 987dfc337ab..91939304bbc 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -99,24 +99,24 @@ class EagerBlobObject final : public user_op::Tensor, EagerBlobObject(const EagerBlobObject&) = delete; EagerBlobObject(EagerBlobObject&&) = delete; EagerBlobObject(const std::shared_ptr& mem_case, - const Symbol& local_tensor_meta, - const std::shared_ptr& mut_local_tensor_meta, + const Symbol& static_local_tensor_meta, + const std::shared_ptr& dynamic_local_tensor_meta, DataType data_type, const std::shared_ptr& tensor_storage) - : EagerBlobObject(mem_case, local_tensor_meta, mut_local_tensor_meta, data_type, + : EagerBlobObject(mem_case, static_local_tensor_meta, dynamic_local_tensor_meta, data_type, tensor_storage, intrusive::shared_ptr()) {} EagerBlobObject(const std::shared_ptr& mem_case, - const Symbol& local_tensor_meta, - const std::shared_ptr& mut_local_tensor_meta, + const Symbol& static_local_tensor_meta, + const std::shared_ptr& dynamic_local_tensor_meta, DataType data_type, const std::shared_ptr& tensor_storage, const intrusive::shared_ptr& dep_object); ~EagerBlobObject() { tensor_storage_.reset(); } const std::shared_ptr& mut_tensor_meta() { - return mut_local_tensor_meta_; + return dynamic_local_tensor_meta_; } // Getters - const Symbol& tensor_meta() const { return local_tensor_meta_; } + const Symbol& tensor_meta() const { return static_local_tensor_meta_; } // user_op::TensorDesc overrides const Shape& shape() const override; @@ -233,8 +233,8 @@ class EagerBlobObject final : public user_op::Tensor, // NOTE: Will be removed soon. Avoid to use it whenever possible. 
BlobDesc blob_desc_; std::unique_ptr blob_; - Symbol local_tensor_meta_; - std::shared_ptr mut_local_tensor_meta_; + Symbol static_local_tensor_meta_; + std::shared_ptr dynamic_local_tensor_meta_; }; using EagerBlobObjectList = small_vector, kOpArgsReservedSize>; From 53e21e41d6568bbc728b7f05f1945a270690e7d3 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 25 Jul 2022 12:19:10 +0800 Subject: [PATCH 64/67] refine --- oneflow/core/common/tensor_meta.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/oneflow/core/common/tensor_meta.h b/oneflow/core/common/tensor_meta.h index c84fef5a02f..5f71758eecb 100644 --- a/oneflow/core/common/tensor_meta.h +++ b/oneflow/core/common/tensor_meta.h @@ -186,10 +186,6 @@ class GlobalTensorMeta : public TensorMeta { Symbol nd_sbp() const { return nd_sbp_; } Symbol parallel_desc() const { return parallel_desc_; } - void set_nd_sbp(Symbol val) { nd_sbp_ = val; } - - void set_parallel_desc(Symbol val) { parallel_desc_ = val; } - size_t CalcHashValue() const; private: From 8f91ce7f2b9940b206eefe05e2ffdbdc66c0c307 Mon Sep 17 00:00:00 2001 From: binbinHan Date: Mon, 25 Jul 2022 14:25:23 +0800 Subject: [PATCH 65/67] fix typo Co-authored-by: Houjiang Chen --- oneflow/user/kernels/stateful_opkernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/stateful_opkernel.h b/oneflow/user/kernels/stateful_opkernel.h index 607a050fea3..08c78316119 100644 --- a/oneflow/user/kernels/stateful_opkernel.h +++ b/oneflow/user/kernels/stateful_opkernel.h @@ -130,7 +130,7 @@ class StatefulOpKernel final { std::vector input_tuple_indexes4mut_ibns_; std::vector output_tuple_indexes4mut_obns_; std::vector output_tuple_indexes4mut2_obns_; - HashMap output_tuple_indexe2is_mut2_type_; + HashMap output_tuple_indexes2is_mut2_type_; }; } // namespace one From 44d3728998885b05dbdfc7e4dc65904191afe5f7 Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 25 Jul 2022 14:30:08 +0800 Subject: [PATCH 66/67] fxi typo --- 
oneflow/user/kernels/stateful_opkernel.cpp | 8 ++++---- oneflow/user/kernels/stateful_opkernel.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 164a51e2fba..276f04ce3cf 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -748,7 +748,7 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr std::vector* input_tuple_indexes4mut_ibns, std::vector* output_tuple_indexes4mut_obns, std::vector* output_tuple_indexes4mut2_obns, - HashMap* output_tuple_indexe2is_mut2_type) { + HashMap* output_tuple_indexes2is_mut2_type) { const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name()); CHECK_NOTNULL_OR_RETURN(op_reg_val); @@ -801,10 +801,10 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr const std::string obn = GenRepeatedBn(pair.first, pair.second); if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) { output_tuple_indexes4mut_obns->emplace_back(i); - output_tuple_indexe2is_mut2_type->emplace(i, false); + output_tuple_indexes2is_mut2_type->emplace(i, false); } else { output_tuple_indexes4mut2_obns->emplace_back(i); - output_tuple_indexe2is_mut2_type->emplace(i, true); + output_tuple_indexes2is_mut2_type->emplace(i, true); } } return Maybe::Ok(); @@ -851,7 +851,7 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr op_conf, input_arg_tuple->indexed_arg_name_and_index(), output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_, &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_, - &opkernel->output_tuple_indexes4mut2_obns_, &opkernel->output_tuple_indexe2is_mut2_type_)); + &opkernel->output_tuple_indexes4mut2_obns_, &opkernel->output_tuple_indexes2is_mut2_type_)); return opkernel; } diff --git 
a/oneflow/user/kernels/stateful_opkernel.h b/oneflow/user/kernels/stateful_opkernel.h index 607a050fea3..6d62f6260a5 100644 --- a/oneflow/user/kernels/stateful_opkernel.h +++ b/oneflow/user/kernels/stateful_opkernel.h @@ -72,7 +72,7 @@ class StatefulOpKernel final { } bool output_is_mut2_type(int64_t index) const { - return output_tuple_indexe2is_mut2_type_.at(index); + return output_tuple_indexes2is_mut2_type_.at(index); } const AttrMap& base_attrs() const { return base_attrs_; } @@ -130,7 +130,7 @@ class StatefulOpKernel final { std::vector input_tuple_indexes4mut_ibns_; std::vector output_tuple_indexes4mut_obns_; std::vector output_tuple_indexes4mut2_obns_; - HashMap output_tuple_indexe2is_mut2_type_; + HashMap output_tuple_indexes2is_mut2_type_; }; } // namespace one From b2f96961b1de18307890911f3c4ae5d15223e25a Mon Sep 17 00:00:00 2001 From: clackhan Date: Mon, 25 Jul 2022 15:13:29 +0800 Subject: [PATCH 67/67] use OpArgsVector --- oneflow/core/common/op_args_vector.h | 29 +++++++++++++++++++ .../core/framework/local_tensor_infer_cache.h | 6 +--- oneflow/user/kernels/stateful_opkernel.cpp | 19 ++++++------ oneflow/user/kernels/stateful_opkernel.h | 19 ++++++------ 4 files changed, 49 insertions(+), 24 deletions(-) create mode 100644 oneflow/core/common/op_args_vector.h diff --git a/oneflow/core/common/op_args_vector.h b/oneflow/core/common/op_args_vector.h new file mode 100644 index 00000000000..8aacdf19fdc --- /dev/null +++ b/oneflow/core/common/op_args_vector.h @@ -0,0 +1,29 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_ +#define ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_ + +#include "oneflow/core/common/small_vector.h" +#include "oneflow/core/common/op_args_reserved_size.h" + +namespace oneflow { + +template +using OpArgsVector = small_vector; + +} + +#endif // ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_ diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h index 12279b7dcf3..45a0eb6dde9 100644 --- a/oneflow/core/framework/local_tensor_infer_cache.h +++ b/oneflow/core/framework/local_tensor_infer_cache.h @@ -18,8 +18,7 @@ limitations under the License. 
#include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" -#include "oneflow/core/common/small_vector.h" -#include "oneflow/core/common/op_args_reserved_size.h" +#include "oneflow/core/common/op_args_vector.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" @@ -31,9 +30,6 @@ class Device; namespace one { -template -using OpArgsVector = small_vector; - class TensorTuple; class UserOpExpr; diff --git a/oneflow/user/kernels/stateful_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp index 276f04ce3cf..c1ebeb71c1d 100644 --- a/oneflow/user/kernels/stateful_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -741,14 +741,13 @@ class UserKernelInitAndCacheContext final : public user_op::KernelInitContext, namespace { -Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr& op_conf, - const ArgVec& indexed_input_pairs, - const ArgVec& indexed_output_pairs, - std::vector* input_tuple_indexes4const_ibns, - std::vector* input_tuple_indexes4mut_ibns, - std::vector* output_tuple_indexes4mut_obns, - std::vector* output_tuple_indexes4mut2_obns, - HashMap* output_tuple_indexes2is_mut2_type) { +Maybe InitTensorTupleIndexes4Bns( + const std::shared_ptr& op_conf, const ArgVec& indexed_input_pairs, + const ArgVec& indexed_output_pairs, OpArgsVector* input_tuple_indexes4const_ibns, + OpArgsVector* input_tuple_indexes4mut_ibns, + OpArgsVector* output_tuple_indexes4mut_obns, + OpArgsVector* output_tuple_indexes4mut2_obns, + small_vector* output_tuple_indexes2is_mut2_type) { const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name()); CHECK_NOTNULL_OR_RETURN(op_reg_val); @@ -801,10 +800,10 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr const std::string obn = GenRepeatedBn(pair.first, pair.second); if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) { 
output_tuple_indexes4mut_obns->emplace_back(i); - output_tuple_indexes2is_mut2_type->emplace(i, false); + output_tuple_indexes2is_mut2_type->emplace_back(false); } else { output_tuple_indexes4mut2_obns->emplace_back(i); - output_tuple_indexes2is_mut2_type->emplace(i, true); + output_tuple_indexes2is_mut2_type->emplace_back(true); } } return Maybe::Ok(); diff --git a/oneflow/user/kernels/stateful_opkernel.h b/oneflow/user/kernels/stateful_opkernel.h index 6d62f6260a5..32d1f165f31 100644 --- a/oneflow/user/kernels/stateful_opkernel.h +++ b/oneflow/user/kernels/stateful_opkernel.h @@ -24,6 +24,7 @@ limitations under the License. #include "oneflow/core/framework/user_op_kernel_registry.h" #include "oneflow/core/framework/arg_tuple.h" #include "oneflow/core/framework/op_interpreter.h" +#include "oneflow/core/common/op_args_vector.h" namespace oneflow { @@ -58,16 +59,16 @@ class StatefulOpKernel final { const Symbol& stream() const { return stream_; } const std::shared_ptr& mem_case() const { return stream_->device()->mem_case(); } const std::string& op_type_name() const { return op_conf_->user_conf().op_type_name(); } - const std::vector& input_tuple_indexes4const_ibns() const { + const OpArgsVector& input_tuple_indexes4const_ibns() const { return input_tuple_indexes4const_ibns_; } - const std::vector& input_tuple_indexes4mut_ibns() const { + const OpArgsVector& input_tuple_indexes4mut_ibns() const { return input_tuple_indexes4mut_ibns_; } - const std::vector& output_tuple_indexes4mut_obns() const { + const OpArgsVector& output_tuple_indexes4mut_obns() const { return output_tuple_indexes4mut_obns_; } - const std::vector& output_tuple_indexes4mut2_obns() const { + const OpArgsVector& output_tuple_indexes4mut2_obns() const { return output_tuple_indexes4mut2_obns_; } @@ -126,11 +127,11 @@ class StatefulOpKernel final { HashMap> op_kernel_state_map_; HashMap> op_kernel_cache_map_; HashMap infer_tmp_size_fn_map_; - std::vector input_tuple_indexes4const_ibns_; - std::vector 
input_tuple_indexes4mut_ibns_; - std::vector output_tuple_indexes4mut_obns_; - std::vector output_tuple_indexes4mut2_obns_; - HashMap output_tuple_indexes2is_mut2_type_; + OpArgsVector input_tuple_indexes4const_ibns_; + OpArgsVector input_tuple_indexes4mut_ibns_; + OpArgsVector output_tuple_indexes4mut_obns_; + OpArgsVector output_tuple_indexes4mut2_obns_; + OpArgsVector output_tuple_indexes2is_mut2_type_; }; } // namespace one