From 9a5e7506c75257e3a317847da5ce9d0519a85e17 Mon Sep 17 00:00:00 2001 From: Li Xinqi Date: Thu, 23 Jun 2022 00:34:41 +0800 Subject: [PATCH] Decouple stream and instruction (#7607) * remove deprecated python api * backup code * backup code * fix compiler complaints * fix typo in refactoring * kMockDevice * add unit test test_mock.py * revert mock kernels * vert DEVICE_TYPE_SEQ * mock placement * address pr comments * register device kCriticalSectionDevice and kLazyJobLauncher * kControlDevice * Stream::vm_stream_ * fix compiler complaints * backup code * rename StreamIsTransport to IsCommNetStream * decouple vm::StreamType and vm::InstructionType * fix compiler complaints * remove 'gpu' related code * address static analyzer complaints * address static analyzer complaints * remove unused module in test_mock.py * the Env is never destroyed. * export Env into python * more unittests * export unittest.TestCase in framework/unittest.py * SwitchToShuttingDownPhase * optional is_normal_exit * VirtualMachine::CloseVMThreads * Delete env_api.h env_api.h is deleted by master * reshape_only_one_dim_infered * address pr comments * rollback flow.env.all_device_placement * no distributed running test_shutting_down.py * auto format by CI * expand lifetime of module oneflow in test_shutting_down.py * refine del depend on of * fix oneflow.placement.__str__ * revert GlobalSync * init_producer_stream in oneflow.from_numpy * debug code for vm * init disable_vm_threads_ in VirtualMachine::VirtualMachine * Update oneflow/core/vm/virtual_machine.h Co-authored-by: daquexian * create stream in forked subprocesses. * refactor StreamRoleSwitch to StreamRoleVisistor * ThreadLocalGuard * auto format by CI * fix compiler complaints * fix static analyzer complaints * VirtualMachine::GetVmStream * fix static analyzer complaints * reimplement AddAndReadVector by std::deque * reimplement AddAndReadVector * merge master * increase atol for test_consistent_rnn_cell.py * StreamRole::AsyncLaunchedCommNet is bound to EventRecordedCudaStreamType * auto format by CI * remove StreamRoleVisitor::VisitInvalid * no copy in AddAndReadVector * fix bug of AddAndReadVector::size_ * disable terminfo to fix missing terminfo symbols Signed-off-by: daquexian * auto format by CI * fix AddAndReadVector::GetGranularity * remove bad unittest * auto format by CI * rename CallInstructionType to OpCallInstructionType * static variable GlobalSingletonPtr is a unique_ptr * replace ++atomic_cnt with atomic_cnt.fetch_add(1, std::memory_order_relaxed) * AddAndReadVector::operator[] * change comments 'lock free' to 'thread safe' * rename StatefulLocalOpKernel to StatefulOpKernel * rename VirtualMachine::vm_ to VirtualMachine::engine_ * mark VirtualMachine::NoMoreErasedInstructions private * mark VirtualMachine::FindOrCreateScheduleLocalDepObject private * remove unused version of VirtualMachineEngine::Receive * rename argname for VirtualMachineEngine::Receive * rename unused PendingInstructionList * rename AddAndReadVector to SteadyVector * optimize SteadyVector::operator[] by __builtin_clzll * refactor SteadyVector::granularity2vector_ to SteadyVector::granularity2data_ * reduce usage of steady_vector::size_ * rename unused anounymous namespace * greater atol for test_consistent_tensordot.py * fix BarrierInstructionType::ComputeInFuseMode * revert container_util.h * run AccessBlobByCallback in default stream of tensor->device * reslove static check * reslove static check * SteadyVector::MutableOrAdd Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: chengtbf <472491134@qq.com> Co-authored-by: oneflow-ci-bot Co-authored-by: Xiaoyu Xu Co-authored-by: daquexian Co-authored-by: binbinHan --- oneflow/api/python/functional/tensor_api.cpp | 6 +- oneflow/api/python/vm/id_generator.cpp | 41 --- oneflow/core/boxing/slice_boxing_util.h | 1 + oneflow/core/common/device_type.proto | 2 +- .../singleton_ptr.h} | 32 +- oneflow/core/common/steady_vector.h | 102 ++++++ .../steady_vector_test.cpp} | 22 +- oneflow/core/common/stream_role.h | 60 ++-- oneflow/core/eager/blob_instruction_type.cpp | 2 +- oneflow/core/eager/blob_instruction_type.h | 92 +++++- .../eager/cpu_opkernel_instruction_type.cpp | 36 --- ...pp => critical_section_instruction_type.h} | 20 +- .../critical_section_phy_instr_operand.cpp | 17 +- .../critical_section_phy_instr_operand.h | 37 ++- .../core/eager/cuda_blob_instruction_type.cpp | 59 ---- .../eager/cuda_opkernel_instruction_type.cpp | 74 ----- oneflow/core/eager/eager_blob_object.h | 20 +- ...n_type.cpp => lazy_job_instruction_type.h} | 17 +- .../core/eager/lazy_job_phy_instr_operand.cpp | 27 +- ..._type.cpp => op_call_instruction_type.cpp} | 54 ++-- ...tion_type.h => op_call_instruction_type.h} | 17 +- ...rand.cpp => op_call_phy_instr_operand.cpp} | 25 +- ..._operand.h => op_call_phy_instr_operand.h} | 41 ++- .../release_tensor_arg_phy_instr_operand.h | 5 +- ....cpp => release_tensor_instruction_type.h} | 79 +++-- .../core/framework/instructions_builder.cpp | 221 +++++++------ oneflow/core/framework/instructions_builder.h | 42 +-- oneflow/core/framework/op_expr.cpp | 8 +- oneflow/core/framework/op_expr.h | 6 +- .../eager_consistent_op_interpreter.cpp | 10 +- .../eager_mirrored_op_interpreter.cpp | 7 +- oneflow/core/framework/stream.cpp | 50 ++- oneflow/core/framework/stream.h | 29 +- .../stream_get_call_instruction_name.h | 99 ------ .../stream_get_release_instruction_name.h | 99 ------ .../framework/stream_get_stream_role_name.h | 40 +++ .../framework/stream_is_comm_net_stream.h | 19 +- oneflow/core/framework/stream_mgr.cpp | 61 ++++ oneflow/core/framework/stream_mgr.h | 48 +++ .../core/framework/stream_need_soft_sync.h | 25 +- .../framework/stream_on_independent_thread.h | 37 +++ .../core/framework/tensor_consistent_id.cpp | 1 + oneflow/core/framework/tensor_impl.cpp | 5 +- oneflow/core/vm/barrier_instruction_type.h | 66 ++++ oneflow/core/vm/control_stream_type.cpp | 13 +- oneflow/core/vm/control_stream_type.h | 4 - oneflow/core/vm/cpu_stream_type.cpp | 16 +- oneflow/core/vm/cpu_stream_type.h | 4 - .../critical_section_status_querier.h | 6 +- .../critical_section_stream_type.cpp | 18 +- .../critical_section_stream_type.h | 10 +- oneflow/core/vm/cuda_copy_d2h_stream_type.cpp | 19 +- oneflow/core/vm/cuda_copy_d2h_stream_type.h | 4 - oneflow/core/vm/cuda_copy_h2d_stream_type.cpp | 18 +- oneflow/core/vm/cuda_copy_h2d_stream_type.h | 4 - oneflow/core/vm/cuda_stream_type.cpp | 18 +- oneflow/core/vm/cuda_stream_type.h | 4 - ...pp => event_recorded_cuda_stream_type.cpp} | 36 +-- ...pe.h => event_recorded_cuda_stream_type.h} | 16 +- ...ction_type.cpp => fuse_instruction_type.h} | 32 +- oneflow/core/vm/fuse_phy_instr_operand.h | 9 +- oneflow/core/vm/id_generator.cpp | 44 --- oneflow/core/vm/id_generator.h | 60 ---- oneflow/core/vm/id_util.cpp | 91 ------ oneflow/core/vm/id_util.h | 64 ---- oneflow/core/vm/instr_type_id.h | 81 ----- oneflow/core/vm/instruction.cpp | 59 +--- oneflow/core/vm/instruction.h | 75 ++--- oneflow/core/vm/instruction.proto | 49 --- oneflow/core/vm/instruction_type.cpp | 28 -- oneflow/core/vm/instruction_type.h | 27 +- .../{eager => vm}/lazy_job_device_context.h | 6 +- .../{eager => vm}/lazy_job_stream_type.cpp | 18 +- .../core/{eager => vm}/lazy_job_stream_type.h | 10 +- oneflow/core/vm/runtime_instr_type_id.h | 52 --- .../core/vm/sequential_instruction_type.cpp | 105 ------- oneflow/core/vm/stream.cpp | 35 +-- oneflow/core/vm/stream.h | 48 ++- oneflow/core/vm/stream_desc.cpp | 36 --- oneflow/core/vm/stream_desc.h | 99 ------ oneflow/core/vm/stream_get_stream_type.h | 108 +++++++ oneflow/core/vm/stream_runtime_desc.h | 85 ----- oneflow/core/vm/stream_type.h | 7 - oneflow/core/vm/thread_ctx.cpp | 2 +- oneflow/core/vm/thread_ctx.h | 17 +- oneflow/core/vm/virtual_machine.cpp | 296 ++++++++++++------ oneflow/core/vm/virtual_machine.h | 48 ++- oneflow/core/vm/virtual_machine_engine.cpp | 100 +----- oneflow/core/vm/virtual_machine_engine.h | 46 +-- oneflow/core/vm/virtual_machine_scope.cpp | 2 +- oneflow/core/vm/vm_desc.cpp | 70 ----- oneflow/core/vm/vm_desc.h | 74 ----- oneflow/core/vm/vm_object.h | 3 - oneflow/core/vm/vm_util.cpp | 7 +- ...cal_opkernel.cpp => stateful_opkernel.cpp} | 26 +- ...l_local_opkernel.h => stateful_opkernel.h} | 28 +- python/oneflow/nn/graph/block.py | 2 +- python/oneflow/test/exceptions/test_device.py | 5 +- .../test/modules/test_consistent_tensordot.py | 2 +- .../automated_test_util/profiler.py | 4 +- .../torch_flow_dual_object.py | 2 +- 101 files changed, 1443 insertions(+), 2470 deletions(-) delete mode 100644 oneflow/api/python/vm/id_generator.cpp rename oneflow/core/{eager/cpu_blob_instruction_type.cpp => common/singleton_ptr.h} (55%) create mode 100644 oneflow/core/common/steady_vector.h rename oneflow/core/{vm/stream_runtime_desc.cpp => common/steady_vector_test.cpp} (60%) delete mode 100644 oneflow/core/eager/cpu_opkernel_instruction_type.cpp rename oneflow/core/eager/{critical_section_instruction_type.cpp => critical_section_instruction_type.h} (92%) delete mode 100644 oneflow/core/eager/cuda_blob_instruction_type.cpp delete mode 100644 oneflow/core/eager/cuda_opkernel_instruction_type.cpp rename oneflow/core/eager/{lazy_job_instruction_type.cpp => lazy_job_instruction_type.h} (93%) rename oneflow/core/eager/{opkernel_instruction_type.cpp => op_call_instruction_type.cpp} (77%) rename oneflow/core/eager/{opkernel_instruction_type.h => op_call_instruction_type.h} (70%) rename oneflow/core/eager/{local_call_opkernel_phy_instr_operand.cpp => op_call_phy_instr_operand.cpp} (78%) rename oneflow/core/eager/{local_call_opkernel_phy_instr_operand.h => op_call_phy_instr_operand.h} (78%) rename oneflow/core/eager/{release_tensor_instruction_type.cpp => release_tensor_instruction_type.h} (53%) delete mode 100644 oneflow/core/framework/stream_get_call_instruction_name.h delete mode 100644 oneflow/core/framework/stream_get_release_instruction_name.h create mode 100644 oneflow/core/framework/stream_get_stream_role_name.h create mode 100644 oneflow/core/framework/stream_mgr.cpp create mode 100644 oneflow/core/framework/stream_mgr.h create mode 100644 oneflow/core/framework/stream_on_independent_thread.h create mode 100644 oneflow/core/vm/barrier_instruction_type.h rename oneflow/core/{eager => vm}/critical_section_status_querier.h (91%) rename oneflow/core/{eager => vm}/critical_section_stream_type.cpp (75%) rename oneflow/core/{eager => vm}/critical_section_stream_type.h (80%) rename oneflow/core/vm/{async_cuda_stream_type.cpp => event_recorded_cuda_stream_type.cpp} (60%) rename oneflow/core/vm/{async_cuda_stream_type.h => event_recorded_cuda_stream_type.h} (75%) rename oneflow/core/vm/{fuse_instruction_type.cpp => fuse_instruction_type.h} (58%) delete mode 100644 oneflow/core/vm/id_generator.cpp delete mode 100644 oneflow/core/vm/id_generator.h delete mode 100644 oneflow/core/vm/id_util.cpp delete mode 100644 oneflow/core/vm/id_util.h delete mode 100644 oneflow/core/vm/instr_type_id.h delete mode 100644 oneflow/core/vm/instruction.proto rename oneflow/core/{eager => vm}/lazy_job_device_context.h (93%) rename oneflow/core/{eager => vm}/lazy_job_stream_type.cpp (75%) rename oneflow/core/{eager => vm}/lazy_job_stream_type.h (81%) delete mode 100644 oneflow/core/vm/runtime_instr_type_id.h delete mode 100644 oneflow/core/vm/sequential_instruction_type.cpp delete mode 100644 oneflow/core/vm/stream_desc.cpp delete mode 100644 oneflow/core/vm/stream_desc.h create mode 100644 oneflow/core/vm/stream_get_stream_type.h delete mode 100644 oneflow/core/vm/stream_runtime_desc.h delete mode 100644 oneflow/core/vm/vm_desc.cpp delete mode 100644 oneflow/core/vm/vm_desc.h rename oneflow/user/kernels/{stateful_local_opkernel.cpp => stateful_opkernel.cpp} (96%) rename oneflow/user/kernels/{stateful_local_opkernel.h => stateful_opkernel.h} (95%) diff --git a/oneflow/api/python/functional/tensor_api.cpp b/oneflow/api/python/functional/tensor_api.cpp index b1a867e8ea7..8378daa6157 100644 --- a/oneflow/api/python/functional/tensor_api.cpp +++ b/oneflow/api/python/functional/tensor_api.cpp @@ -287,8 +287,10 @@ class LocalTensorSharedNumpyDataFunctor { // Init blob JUST(tensor_impl->InitEagerBlobObject(NewLocalDepObject(), /*pin_memory=*/false)); - const auto& stream = GetDefaultStreamByDevice(device); - JUST(tensor_impl->eager_blob_object())->set_last_used_stream(stream); + const auto& stream = JUST(GetDefaultStreamByDevice(device)); + const auto& eager_blob_object = JUST(tensor_impl->eager_blob_object()); + JUST(eager_blob_object->init_producer_stream(stream)); + eager_blob_object->set_last_used_stream(stream); std::shared_ptr out(new MirroredTensor(tensor_impl)); return out; } diff --git a/oneflow/api/python/vm/id_generator.cpp b/oneflow/api/python/vm/id_generator.cpp deleted file mode 100644 index 03586b603d6..00000000000 --- a/oneflow/api/python/vm/id_generator.cpp +++ /dev/null @@ -1,41 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include -#include "oneflow/api/python/of_api_registry.h" -#include "oneflow/core/vm/id_generator.h" - -namespace oneflow { -namespace vm { - -namespace py = pybind11; - -ONEFLOW_API_PYBIND11_MODULE("vm", m) { - py::class_>(m, "IdGenerator"); - py::class_>( - m, "PhysicalIdGenerator") - .def(py::init<>()) - .def("NewSymbolId", &PhysicalIdGenerator::NewSymbolId) - .def("NewObjectId", &PhysicalIdGenerator::NewSymbolId); - - py::class_>( - m, "LogicalIdGenerator") - .def(py::init<>()) - .def("NewSymbolId", &LogicalIdGenerator::NewSymbolId) - .def("NewObjectId", &LogicalIdGenerator::NewObjectId); -} - -} // namespace vm -} // namespace oneflow diff --git a/oneflow/core/boxing/slice_boxing_util.h b/oneflow/core/boxing/slice_boxing_util.h index 83fe2f619b9..d59cd6f6317 100644 --- a/oneflow/core/boxing/slice_boxing_util.h +++ b/oneflow/core/boxing/slice_boxing_util.h @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/placed_nd_sbp.h" +#include "oneflow/core/job/parallel_desc.h" namespace oneflow { diff --git a/oneflow/core/common/device_type.proto b/oneflow/core/common/device_type.proto index bc083768124..2b94416c8cb 100644 --- a/oneflow/core/common/device_type.proto +++ b/oneflow/core/common/device_type.proto @@ -5,5 +5,5 @@ enum DeviceType { kInvalidDevice = 0; kCPU = 1; kCUDA = 2; - kMockDevice = 3; + kMockDevice = 3; // pseudo device for test. } diff --git a/oneflow/core/eager/cpu_blob_instruction_type.cpp b/oneflow/core/common/singleton_ptr.h similarity index 55% rename from oneflow/core/eager/cpu_blob_instruction_type.cpp rename to oneflow/core/common/singleton_ptr.h index b33a1e607c2..eecb0a4cdee 100644 --- a/oneflow/core/eager/cpu_blob_instruction_type.cpp +++ b/oneflow/core/common/singleton_ptr.h @@ -13,21 +13,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/eager/blob_instruction_type.h" -#include "oneflow/core/vm/cpu_stream_type.h" +#ifndef ONEFLOW_CORE_COMMON_SINGLETON_PTR_H_ +#define ONEFLOW_CORE_COMMON_SINGLETON_PTR_H_ + +#include namespace oneflow { -namespace vm { -class CpuAccessBlobByCallbackInstructionType final : public AccessBlobByCallbackInstructionType { - public: - CpuAccessBlobByCallbackInstructionType() = default; - ~CpuAccessBlobByCallbackInstructionType() override = default; +namespace private_detail { + +template +const T* GlobalSingletonPtr() { + static std::unique_ptr value(new T()); + return value.get(); +} - using stream_type = vm::CpuStreamType; -}; -COMMAND(vm::RegisterInstructionType( - "cpu.AccessBlobByCallback")); +} // namespace private_detail + +template +const T* SingletonPtr() { + thread_local const T* value = private_detail::GlobalSingletonPtr(); + return value; +} -} // namespace vm } // namespace oneflow + +#endif // ONEFLOW_CORE_COMMON_SINGLETON_PTR_H_ diff --git a/oneflow/core/common/steady_vector.h b/oneflow/core/common/steady_vector.h new file mode 100644 index 00000000000..f2a7e06877a --- /dev/null +++ b/oneflow/core/common/steady_vector.h @@ -0,0 +1,102 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_ +#define ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_ + +#include +#include +#include +#include +#include + +namespace oneflow { + +template +class SteadyVector { + public: + SteadyVector() : size_(0) {} + ~SteadyVector() = default; + + using value_type = const T; + using size_type = size_t; + + // thread safe. + size_t size() const { return size_; } + + // thread safe. + const T& at(size_t index) const { + CHECK_GE(index, 0); + CHECK_LT(index, size_); + return (*this)[index]; + } + + // thread safe. + const T& operator[](size_t index) const { + int gran = 0; + size_t start = 0; + GetGranularityAndStart(index, &gran, &start); + return granularity2data_[gran].get()[index - start]; + } + + void push_back(const T& elem) { *MutableOrAdd(size_) = elem; } + + // `index` shoule be <= size() + T* MutableOrAdd(size_t index) { + std::unique_lock lock(mutex_); + size_t size = size_; + CHECK_LE(index, size) << "index out of range"; + if (index == size) { + int granularity = GetGranularity(size); + if (size + 1 == (1 << granularity)) { + CHECK_LT(granularity, N); + granularity2data_[granularity].reset(new T[1 << granularity]); + } + ++size_; + } + return Mutable(index); + } + + private: + T* Mutable(size_t index) { + int gran = 0; + size_t start = 0; + GetGranularityAndStart(index, &gran, &start); + return &granularity2data_[gran].get()[index - start]; + } + + static void GetGranularityAndStart(size_t index, int* gran, size_t* start) { + *gran = GetGranularity(index); + *start = (1 << *gran) - 1; + } + +#ifdef __GNUC__ +#define LOG2(x) ((unsigned)(8 * sizeof(unsigned long long) - __builtin_clzll((x)) - 1)) +#else +#define LOG2(x) std::log2(x) +#endif + + static int GetGranularity(size_t index) { return LOG2(index + 1); } + +#undef LOG2 + + std::atomic size_; + std::mutex mutex_; + std::array, N> granularity2data_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_ diff --git a/oneflow/core/vm/stream_runtime_desc.cpp b/oneflow/core/common/steady_vector_test.cpp similarity index 60% rename from oneflow/core/vm/stream_runtime_desc.cpp rename to oneflow/core/common/steady_vector_test.cpp index 68d2eff4a81..bfc5fdb19b8 100644 --- a/oneflow/core/vm/stream_runtime_desc.cpp +++ b/oneflow/core/common/steady_vector_test.cpp @@ -13,16 +13,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/vm/stream_runtime_desc.h" +#include "gtest/gtest.h" +#include "oneflow/core/common/steady_vector.h" namespace oneflow { -namespace vm { +namespace test { -void StreamRtDesc::__Init__(StreamDesc* stream_desc) { - const StreamType* stream_type = &stream_desc->stream_type(); - reset_stream_desc(stream_desc); - set_stream_type(stream_type); +void TestSteadyVector(int granularity) { + CHECK_GT(granularity, 0); + SteadyVector vec; + ASSERT_EQ(vec.size(), 0); + for (int i = 0; i < (1 << granularity); ++i) { + vec.push_back(i); + ASSERT_EQ(vec.at(i), i); + ASSERT_EQ(vec.size(), i + 1); + } } -} // namespace vm +TEST(SteadyVector, simple) { TestSteadyVector(6); } + +} // namespace test } // namespace oneflow diff --git a/oneflow/core/common/stream_role.h b/oneflow/core/common/stream_role.h index 27fdd4256e0..9e7e5b47fa5 100644 --- a/oneflow/core/common/stream_role.h +++ b/oneflow/core/common/stream_role.h @@ -19,44 +19,44 @@ limitations under the License. #include #include #include "oneflow/core/common/preprocessor.h" +#include "glog/logging.h" namespace oneflow { -#define STREAM_ROLE_SEQ \ - OF_PP_MAKE_TUPLE_SEQ(kCompute) \ - OF_PP_MAKE_TUPLE_SEQ(kHost2Device) \ - OF_PP_MAKE_TUPLE_SEQ(kDevice2Host) \ - OF_PP_MAKE_TUPLE_SEQ(kSyncedLaunchedCommNet) \ - OF_PP_MAKE_TUPLE_SEQ(kAsyncedLaunchedCommNet) \ - OF_PP_MAKE_TUPLE_SEQ(kCriticalSection) - enum class StreamRole { kInvalid = 0, -#define DECLARE_STREAM_ROLE(stream_role) stream_role, - OF_PP_FOR_EACH_TUPLE(DECLARE_STREAM_ROLE, STREAM_ROLE_SEQ) -#undef DECLARE_STREAM_ROLE + kCompute, + kHost2Device, + kDevice2Host, + kSyncedLaunchedCommNet, + kAsyncedLaunchedCommNet, + kBarrier, + kCriticalSection, + kLazyJobLauncher }; -static constexpr int kStreamRoleSize = 1 + OF_PP_SEQ_SIZE(STREAM_ROLE_SEQ); - -// Act as a class for overloading functions -template -struct StreamRoleCase {}; - -template -auto StreamRoleSwitch(StreamRole stream_role, Args&&... args) - -> decltype(Functor::Case(StreamRoleCase(), - std::forward(args)...)) { - switch (stream_role) { -#define MAKE_ENTRY(stream_role) \ - case StreamRole::stream_role: \ - return Functor::Case(StreamRoleCase(), std::forward(args)...); - OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, STREAM_ROLE_SEQ) -#undef MAKE_ENTRY - default: - return Functor::Case(StreamRoleCase(), std::forward(args)...); +template +struct StreamRoleVisitor { + template + static auto Visit(StreamRole stream_role, Args&&... args) { + switch (stream_role) { + case StreamRole::kInvalid: LOG(FATAL) << "invalid stream role"; + case StreamRole::kCompute: return DerivedT::VisitCompute(std::forward(args)...); + case StreamRole::kHost2Device: return DerivedT::VisitHost2Device(std::forward(args)...); + case StreamRole::kDevice2Host: return DerivedT::VisitDevice2Host(std::forward(args)...); + case StreamRole::kSyncedLaunchedCommNet: + return DerivedT::VisitSyncedLaunchedCommNet(std::forward(args)...); + case StreamRole::kAsyncedLaunchedCommNet: + return DerivedT::VisitAsyncedLaunchedCommNet(std::forward(args)...); + case StreamRole::kBarrier: return DerivedT::VisitBarrier(std::forward(args)...); + case StreamRole::kCriticalSection: + return DerivedT::VisitCriticalSection(std::forward(args)...); + case StreamRole::kLazyJobLauncher: + return DerivedT::VisitLazyJobLauncher(std::forward(args)...); + } + LOG(FATAL) << "invalid stream role"; } -} +}; } // namespace oneflow diff --git a/oneflow/core/eager/blob_instruction_type.cpp b/oneflow/core/eager/blob_instruction_type.cpp index 3a4454ed8d7..65f04e2dbc9 100644 --- a/oneflow/core/eager/blob_instruction_type.cpp +++ b/oneflow/core/eager/blob_instruction_type.cpp @@ -46,7 +46,7 @@ void AccessBlobByCallbackInstructionType::ComputeInstrMsg( const auto* ptr = dynamic_cast(phy_instr_operand.get()); CHECK_NOTNULL(ptr); - DeviceCtx* device_ctx = instr_msg.phy_instr_stream()->device_ctx().get(); + DeviceCtx* device_ctx = instr_msg.stream().device_ctx().get(); auto* blob = ptr->eager_blob_object()->blob(); OfBlob ofblob(device_ctx->stream(), blob); ptr->callback()(reinterpret_cast(&ofblob)); diff --git a/oneflow/core/eager/blob_instruction_type.h b/oneflow/core/eager/blob_instruction_type.h index c3d1d6121b0..b2182dbf703 100644 --- a/oneflow/core/eager/blob_instruction_type.h +++ b/oneflow/core/eager/blob_instruction_type.h @@ -13,17 +13,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_ +#define ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_ + #include "oneflow/core/intrusive/flat_msg_view.h" #include "oneflow/core/vm/instruction_type.h" +#include "oneflow/core/common/stream_role.h" +#include "oneflow/core/common/singleton_ptr.h" +#include "oneflow/core/vm/cuda_optional_event_record_status_querier.h" +#include "oneflow/core/vm/stream.h" +#include "oneflow/core/device/cuda_event.h" namespace oneflow { namespace vm { -class AccessBlobByCallbackInstructionType : public vm::InstructionType { +class AccessBlobByCallbackInstructionType final : public vm::InstructionType { public: AccessBlobByCallbackInstructionType() = default; ~AccessBlobByCallbackInstructionType() override = default; + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { + return "AccessBlobByCallback"; + } void Compute(vm::Instruction* instruction) const override; void ComputeInFuseMode(vm::InstructionMsg* instruction_msg) const override; @@ -31,13 +42,86 @@ class AccessBlobByCallbackInstructionType : public vm::InstructionType { void ComputeInstrMsg(const vm::InstructionMsg& instruction_msg) const; }; -class RecordEventInstructionType : public vm::InstructionType { +class CpuRecordEventInstructionType final : public vm::InstructionType { + public: + CpuRecordEventInstructionType() = default; + ~CpuRecordEventInstructionType() override = default; + + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { + return "RecordEvent"; + } + void Compute(vm::Instruction* instruction) const override {} +}; + +#ifdef WITH_CUDA + +class CudaRecordEventInstructionType final : public vm::InstructionType { public: - RecordEventInstructionType() = default; - ~RecordEventInstructionType() override = default; + CudaRecordEventInstructionType() = default; + ~CudaRecordEventInstructionType() override = default; + InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAsTailOnly; } + + void InitInstructionStatus(Instruction* instruction) const override { + auto* status_buffer = instruction->mut_status_buffer(); + auto* stream = instruction->mut_stream(); + instruction->stream_type().InitInstructionStatus(*stream, status_buffer); + auto* event_provider = dynamic_cast(stream->device_ctx().get()); + const auto& cuda_event = CHECK_NOTNULL(event_provider)->GetCudaEvent(); + auto* data_ptr = status_buffer->mut_buffer()->mut_data(); + CudaOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_cuda_event(cuda_event); + } + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { + return "RecordEvent"; + } void Compute(vm::Instruction* instruction) const override {} }; +#endif + } // namespace vm + +struct GetRecordEventInstructionType : public StreamRoleVisitor { + static Maybe VisitCompute(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitHost2Device(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitDevice2Host(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitSyncedLaunchedCommNet(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitAsyncedLaunchedCommNet(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitBarrier(DeviceType device_type) { + UNIMPLEMENTED_THEN_RETURN(); + } + static Maybe VisitCriticalSection(DeviceType device_type) { + UNIMPLEMENTED_THEN_RETURN(); + } + static Maybe VisitLazyJobLauncher(DeviceType device_type) { + UNIMPLEMENTED_THEN_RETURN(); + } + + private: + static Maybe GetInstructionType(DeviceType device_type) { + if (device_type == DeviceType::kCPU) { + return SingletonPtr(); + } else if (device_type == DeviceType::kCUDA) { +#ifdef WITH_CUDA + return SingletonPtr(); +#else + UNIMPLEMENTED_THEN_RETURN(); +#endif + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } +}; + } // namespace oneflow +#endif // ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_ diff --git a/oneflow/core/eager/cpu_opkernel_instruction_type.cpp b/oneflow/core/eager/cpu_opkernel_instruction_type.cpp deleted file mode 100644 index 7d3ee257397..00000000000 --- a/oneflow/core/eager/cpu_opkernel_instruction_type.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/common/util.h" -#include "oneflow/core/job/job_desc.h" -#include "oneflow/core/eager/opkernel_instruction_type.h" -#include "oneflow/core/vm/stream.h" -#include "oneflow/core/vm/cpu_stream_type.h" -#include "oneflow/core/vm/instruction.h" - -namespace oneflow { -namespace vm { - -class CpuLocalCallOpKernelInstructionType final : public LocalCallOpKernelInstructionType { - public: - CpuLocalCallOpKernelInstructionType() = default; - ~CpuLocalCallOpKernelInstructionType() override = default; - - using stream_type = vm::CpuStreamType; -}; -COMMAND(vm::RegisterInstructionType("cpu.LocalCallOpKernel")); - -} // namespace vm -} // namespace oneflow diff --git a/oneflow/core/eager/critical_section_instruction_type.cpp b/oneflow/core/eager/critical_section_instruction_type.h similarity index 92% rename from oneflow/core/eager/critical_section_instruction_type.cpp rename to oneflow/core/eager/critical_section_instruction_type.h index 1a4bd0b292d..f96b27b3e95 100644 --- a/oneflow/core/eager/critical_section_instruction_type.cpp +++ b/oneflow/core/eager/critical_section_instruction_type.h @@ -13,9 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_ +#define ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_ -#include "oneflow/core/eager/critical_section_stream_type.h" -#include "oneflow/core/eager/critical_section_status_querier.h" +#include "oneflow/core/vm/critical_section_status_querier.h" #include "oneflow/core/eager/critical_section_phy_instr_operand.h" #include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/framework/nn_graph_if.h" @@ -44,8 +45,9 @@ class CriticalSectionBeginInstructionType final : public InstructionType { CriticalSectionBeginInstructionType() = default; ~CriticalSectionBeginInstructionType() = default; - using stream_type = CriticalSectionStreamType; - + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { + return "CriticalSectionBegin"; + } void Compute(vm::Instruction* instruction) const override { OF_PROFILER_RANGE_GUARD("CriticalSectionBegin"); { @@ -107,8 +109,6 @@ class CriticalSectionBeginInstructionType final : public InstructionType { } }; -COMMAND(RegisterInstructionType("CriticalSectionBegin")); - class CriticalSectionEndInstructionType final : public InstructionType { public: CriticalSectionEndInstructionType(const CriticalSectionEndInstructionType&) = delete; @@ -118,8 +118,9 @@ class CriticalSectionEndInstructionType final : public InstructionType { CriticalSectionEndInstructionType() = default; ~CriticalSectionEndInstructionType() = default; - using stream_type = CriticalSectionStreamType; - + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { + return "CriticalSectionEnd"; + } void Compute(vm::Instruction* instruction) const override { const auto* ptr = instruction->instr_msg().phy_instr_operand().get(); const auto* phy_instr_operand = dynamic_cast(ptr); @@ -130,7 +131,6 @@ class CriticalSectionEndInstructionType final : public InstructionType { } }; -COMMAND(RegisterInstructionType("CriticalSectionEnd")); - } // namespace vm } // namespace oneflow +#endif // ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_ diff --git a/oneflow/core/eager/critical_section_phy_instr_operand.cpp b/oneflow/core/eager/critical_section_phy_instr_operand.cpp index ec6facb370d..bc4f2b7d21e 100644 --- a/oneflow/core/eager/critical_section_phy_instr_operand.cpp +++ b/oneflow/core/eager/critical_section_phy_instr_operand.cpp @@ -22,6 +22,7 @@ limitations under the License. #include "oneflow/core/device/ep_based_event_record.h" #include "oneflow/core/register/ofblob.h" #include "oneflow/core/common/container_util.h" +#include "oneflow/core/vm/stream.h" namespace oneflow { namespace vm { @@ -38,21 +39,9 @@ void CriticalSectionEndPhyInstrOperand::ForEachMirroredObject( DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object())); } -namespace { - -Maybe RawCriticalSectionLocalDepObject() { - const auto& device = JUST(Device::New("cpu")); - return Stream::New(device, StreamRole::kCriticalSection)->mut_schedule_local_dep_object(); -} - -constexpr auto* CriticalSectionLocalDepObject = - DECORATE(&RawCriticalSectionLocalDepObject, ThreadLocal); - -} // namespace - void CriticalSectionBeginPhyInstrOperand::ForEachMutMirroredObject( const std::function& DoEach) const { - DoEach(CHECK_JUST(CriticalSectionLocalDepObject())); + DoEach(vm_stream_->schedule_local_dep_object().get()); } void CriticalSectionBeginPhyInstrOperand::FinishInvalidInterfaceEventRecords() { @@ -121,7 +110,7 @@ void OutputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_b void CriticalSectionEndPhyInstrOperand::ForEachMutMirroredObject( const std::function& DoEach) const { - DoEach(CHECK_JUST(CriticalSectionLocalDepObject())); + DoEach(vm_stream_->schedule_local_dep_object().get()); } } // namespace vm diff --git a/oneflow/core/eager/critical_section_phy_instr_operand.h b/oneflow/core/eager/critical_section_phy_instr_operand.h index f294dde1135..2627c3d6339 100644 --- a/oneflow/core/eager/critical_section_phy_instr_operand.h +++ b/oneflow/core/eager/critical_section_phy_instr_operand.h @@ -33,6 +33,8 @@ using EagerBlobObjectListPtr = namespace vm { +class Stream; + class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand { public: CriticalSectionBeginPhyInstrOperand(const CriticalSectionBeginPhyInstrOperand&) = delete; @@ -46,10 +48,12 @@ class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand { const std::shared_ptr& nn_graph, const one::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& - op_name2end_event_record) + op_name2end_event_record, + vm::Stream* vm_stream) : nn_graph_(nn_graph), eager_blob_objects_(eager_blob_objects), - op_name2end_event_record_(op_name2end_event_record) {} + op_name2end_event_record_(op_name2end_event_record), + vm_stream_(vm_stream) {} const std::shared_ptr& nn_graph() const { return nn_graph_; } const one::EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; } @@ -77,6 +81,7 @@ class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand { std::shared_ptr>> op_name2end_event_record_; HashMap op_name2interface_index_; + vm::Stream* vm_stream_; }; class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeginPhyInstrOperand { @@ -85,8 +90,10 @@ class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeg const std::shared_ptr& nn_graph, const one::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& - op_name2end_event_record) - : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record), + op_name2end_event_record, + vm::Stream* vm_stream) + : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record, + vm_stream), input_dependences_(), output_dependences_() { ForEachConstMirroredObject(SetInserter(&input_dependences_)); @@ -141,8 +148,10 @@ class OutputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBe const std::shared_ptr& nn_graph, const one::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr>>& - op_name2end_event_record) - : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record), + op_name2end_event_record, + vm::Stream* vm_stream) + : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record, + vm_stream), input_dependences_(), output_dependences_() { ForEachConstMirroredObject(SetInserter(&input_dependences_)); @@ -195,8 +204,9 @@ class OutputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBe class CriticalSectionEndPhyInstrOperand : public PhyInstrOperand { public: CriticalSectionEndPhyInstrOperand(const std::shared_ptr& eager_blob_object, - const std::shared_ptr& event_record) - : eager_blob_object_(eager_blob_object), event_record_(event_record) {} + const std::shared_ptr& event_record, + vm::Stream* vm_stream) + : eager_blob_object_(eager_blob_object), event_record_(event_record), vm_stream_(vm_stream) {} virtual ~CriticalSectionEndPhyInstrOperand() = default; const std::shared_ptr& event_record() const { return event_record_; } @@ -208,13 +218,15 @@ class CriticalSectionEndPhyInstrOperand : public PhyInstrOperand { private: std::shared_ptr eager_blob_object_; std::shared_ptr event_record_; + vm::Stream* vm_stream_; }; class InputCriticalSecondEndPhyInstrOperand final : public CriticalSectionEndPhyInstrOperand { public: InputCriticalSecondEndPhyInstrOperand(const std::shared_ptr& eager_blob_object, - const std::shared_ptr& event_record) - : CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record), + const std::shared_ptr& event_record, + vm::Stream* vm_stream) + : CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record, vm_stream), input_dependences_(), output_dependences_() { ForEachConstMirroredObject(SetInserter(&input_dependences_)); @@ -241,8 +253,9 @@ class InputCriticalSecondEndPhyInstrOperand final : public CriticalSectionEndPhy class OutputCriticalSecondEndPhyInstrOperand final : public CriticalSectionEndPhyInstrOperand { public: OutputCriticalSecondEndPhyInstrOperand(const std::shared_ptr& eager_blob_object, - const std::shared_ptr& event_record) - : CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record), + const std::shared_ptr& event_record, + vm::Stream* vm_stream) + : CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record, vm_stream), input_dependences_(), output_dependences_() { ForEachConstMirroredObject(SetInserter(&input_dependences_)); diff --git a/oneflow/core/eager/cuda_blob_instruction_type.cpp b/oneflow/core/eager/cuda_blob_instruction_type.cpp deleted file mode 100644 index 940afcd6d16..00000000000 --- a/oneflow/core/eager/cuda_blob_instruction_type.cpp +++ /dev/null @@ -1,59 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/vm/cpu_stream_type.h" -#ifdef WITH_CUDA -#include "oneflow/core/eager/blob_instruction_type.h" -#include "oneflow/core/vm/cuda_stream_type.h" -#include "oneflow/core/vm/cuda_optional_event_record_status_querier.h" -#include "oneflow/core/vm/stream.h" -#include "oneflow/core/vm/async_cuda_stream_type.h" -#include "oneflow/core/device/cuda_event.h" - -namespace oneflow { -namespace vm { - -class GpuAccessBlobByCallbackInstructionType final : public AccessBlobByCallbackInstructionType { - public: - GpuAccessBlobByCallbackInstructionType() = default; - ~GpuAccessBlobByCallbackInstructionType() override = default; - using stream_type = vm::CudaStreamType; -}; -COMMAND(vm::RegisterInstructionType( - "cuda.AccessBlobByCallback")); - -class GpuRecordEventInstructionType : public RecordEventInstructionType { - public: - GpuRecordEventInstructionType() = default; - ~GpuRecordEventInstructionType() override = default; - using stream_type = vm::CudaStreamType; - - InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAsTailOnly; } - - void InitInstructionStatus(Instruction* instruction) const override { - auto* status_buffer = instruction->mut_status_buffer(); - auto* stream = instruction->mut_stream(); - instruction->stream_type().InitInstructionStatus(*stream, status_buffer); - auto* event_provider = dynamic_cast(stream->device_ctx().get()); - const auto& cuda_event = CHECK_NOTNULL(event_provider)->GetCudaEvent(); - auto* data_ptr = status_buffer->mut_buffer()->mut_data(); - CudaOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_cuda_event(cuda_event); - } -}; -COMMAND(vm::RegisterInstructionType("cuda.RecordEvent")); - -} // namespace vm -} // namespace oneflow -#endif diff --git a/oneflow/core/eager/cuda_opkernel_instruction_type.cpp b/oneflow/core/eager/cuda_opkernel_instruction_type.cpp deleted file mode 100644 index d6a431d02cd..00000000000 --- a/oneflow/core/eager/cuda_opkernel_instruction_type.cpp +++ /dev/null @@ -1,74 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifdef WITH_CUDA - -#include "oneflow/core/common/util.h" -#include "oneflow/core/job/job_desc.h" -#include "oneflow/core/eager/opkernel_instruction_type.h" -#include "oneflow/core/vm/stream.h" -#include "oneflow/core/vm/cuda_stream_type.h" -#include "oneflow/core/vm/async_cuda_stream_type.h" -#include "oneflow/core/vm/cuda_copy_h2d_stream_type.h" -#include "oneflow/core/vm/cuda_copy_d2h_stream_type.h" -#include "oneflow/core/vm/instruction.h" - -namespace oneflow { -namespace vm { - -class CudaLocalCallOpKernelInstructionType final : public LocalCallOpKernelInstructionType { - public: - CudaLocalCallOpKernelInstructionType() = default; - ~CudaLocalCallOpKernelInstructionType() override = default; - - using stream_type = vm::CudaStreamType; -}; -COMMAND( - vm::RegisterInstructionType("cuda.LocalCallOpKernel")); - -class AsyncCudaLocalCallOpKernelInstructionType final : public LocalCallOpKernelInstructionType { - public: - AsyncCudaLocalCallOpKernelInstructionType() = default; - ~AsyncCudaLocalCallOpKernelInstructionType() override = default; - - using stream_type = vm::AsyncCudaStreamType; -}; -COMMAND(vm::RegisterInstructionType( - "async.cuda.LocalCallOpKernel")); - -class CudaH2DLocalCallOpKernelInstructionType final : public LocalCallOpKernelInstructionType { - public: - CudaH2DLocalCallOpKernelInstructionType() = default; - ~CudaH2DLocalCallOpKernelInstructionType() override = default; - - using stream_type = vm::CudaCopyH2DStreamType; -}; -COMMAND(vm::RegisterInstructionType( - "cuda_h2d.LocalCallOpKernel")); - -class CudaD2HLocalCallOpKernelInstructionType final : public LocalCallOpKernelInstructionType { - public: - CudaD2HLocalCallOpKernelInstructionType() = default; - ~CudaD2HLocalCallOpKernelInstructionType() override = default; - - using stream_type = vm::CudaCopyD2HStreamType; -}; -COMMAND(vm::RegisterInstructionType( - "cuda_d2h.LocalCallOpKernel")); - -} // namespace vm -} // namespace oneflow - -#endif diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index 6003b690f94..cb10a32c1d1 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -52,15 +52,15 @@ class TensorStorage { blob_bytes_ = bytes; } - const Optional>& producer_stream() const { return producer_stream_; } - Maybe init_producer_stream(Symbol producer_stream) { + const Optional>& producer_stream() const { return producer_stream_; } + Maybe init_producer_stream(Symbol<::oneflow::Stream> producer_stream) { CHECK_OR_RETURN(!producer_stream_.has_value()); producer_stream_ = producer_stream; return Maybe::Ok(); } - const Optional>& last_used_stream() const { return last_used_stream_; } - void set_last_used_stream(Symbol last_used_stream) { + const Optional>& last_used_stream() const { return last_used_stream_; } + void set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream) { last_used_stream_ = last_used_stream; } @@ -77,8 +77,8 @@ class TensorStorage { size_t blob_bytes_; std::unique_ptr> blob_dptr_; std::unique_ptr non_pod_allocator_; - Optional> producer_stream_; - Optional> last_used_stream_; + Optional> producer_stream_; + Optional> last_used_stream_; std::vector> storage_delete_hooks_; }; @@ -125,17 +125,17 @@ class EagerBlobObject final { void set_is_shape_synced(bool val) { is_shape_synced_ = val; } - const Optional>& producer_stream() const { + const Optional>& producer_stream() const { return tensor_storage_->producer_stream(); } - Maybe init_producer_stream(Symbol producer_stream) { + Maybe init_producer_stream(Symbol<::oneflow::Stream> producer_stream) { return tensor_storage_->init_producer_stream(producer_stream); } - const Optional>& last_used_stream() const { + const Optional>& last_used_stream() const { return tensor_storage_->last_used_stream(); } - void set_last_used_stream(Symbol last_used_stream) { + void set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream) { tensor_storage_->set_last_used_stream(last_used_stream); } diff --git a/oneflow/core/eager/lazy_job_instruction_type.cpp b/oneflow/core/eager/lazy_job_instruction_type.h similarity index 93% rename from oneflow/core/eager/lazy_job_instruction_type.cpp rename to oneflow/core/eager/lazy_job_instruction_type.h index 369d602e70e..b2b8949fff3 100644 --- a/oneflow/core/eager/lazy_job_instruction_type.cpp +++ b/oneflow/core/eager/lazy_job_instruction_type.h @@ -13,9 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_ +#define ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_ -#include "oneflow/core/eager/lazy_job_stream_type.h" -#include "oneflow/core/eager/lazy_job_device_context.h" +#include "oneflow/core/vm/lazy_job_device_context.h" #include "oneflow/core/eager/lazy_job_phy_instr_operand.h" #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/common/container_util.h" @@ -33,8 +34,6 @@ limitations under the License. namespace oneflow { -namespace { - class LazyJobInstance final : public JobInstance { public: LazyJobInstance(const LazyJobInstance&) = delete; @@ -62,8 +61,6 @@ class LazyJobInstance final : public JobInstance { const std::function finish_cb_; }; -} // namespace - namespace vm { class LaunchLazyJobInstructionType final : public InstructionType { // NOLINT @@ -72,7 +69,10 @@ class LaunchLazyJobInstructionType final : public InstructionType { // NOLINT LaunchLazyJobInstructionType(LaunchLazyJobInstructionType&&) = delete; LaunchLazyJobInstructionType() = default; ~LaunchLazyJobInstructionType() = default; - using stream_type = LazyJobStreamType; + + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { + return "LaunchLazyJob"; + } void Compute(vm::Instruction* instruction) const override { const auto& cur_nn_graph = GetCurNNGraph(instruction); auto* device_ctx = GetLazyJobDeviceCtx(instruction); @@ -127,7 +127,6 @@ class LaunchLazyJobInstructionType final : public InstructionType { // NOLINT } }; -COMMAND(RegisterInstructionType("LaunchLazyJob")); - } // namespace vm } // namespace oneflow +#endif // ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_ diff --git a/oneflow/core/eager/lazy_job_phy_instr_operand.cpp b/oneflow/core/eager/lazy_job_phy_instr_operand.cpp index 4eed1c2e3ea..ab9c2c1c375 100644 --- a/oneflow/core/eager/lazy_job_phy_instr_operand.cpp +++ b/oneflow/core/eager/lazy_job_phy_instr_operand.cpp @@ -13,42 +13,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/decorator.h" #include "oneflow/core/eager/lazy_job_phy_instr_operand.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" +#include "oneflow/core/vm/virtual_machine.h" namespace oneflow { namespace vm { -namespace { - -#ifdef WITH_CUDA -Maybe RawGetEagerNcclLocalDepObject(StreamRole stream_role) { - // NOTE(chengcheng): - // Lazy Job instruction need mutual exclusion nccl with Eager nccl. However, when the number of - // processes is more than the number of physical GPUs, the following processes will make an - // error when using local rank to create a EagerNcclLocalDepObject, but we only need an legal - // device so we use device 0. - const auto& device = JUST(Device::New("cpu", 0)); - const auto& stream = Stream::New(device, stream_role); - const auto& local_dep_object = stream->mut_transport_local_dep_object(); - CHECK_OR_RETURN(local_dep_object.has_value()); - return JUST(local_dep_object); -} - -static constexpr auto* GetEagerNcclLocalDepObject = - DECORATE(&RawGetEagerNcclLocalDepObject, ThreadLocalCopiable); -#endif // WITH_CUDA - -} // namespace - void LaunchLazyJobPhyInstrOperand::ForEachMutMirroredObject( const std::function& DoEach) const { for (const auto& eager_blob_object : *param_blob_objects_) { DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object())); } - DoEach(GetStaticGlobalTransportLocalDepObject()); + DoEach( + CHECK_JUST(GlobalMaybe())->FindOrCreateTransportLocalDepObject().Mutable()); } } // namespace vm diff --git a/oneflow/core/eager/opkernel_instruction_type.cpp b/oneflow/core/eager/op_call_instruction_type.cpp similarity index 77% rename from oneflow/core/eager/opkernel_instruction_type.cpp rename to oneflow/core/eager/op_call_instruction_type.cpp index 89f3c341fd4..6381137fc80 100644 --- a/oneflow/core/eager/opkernel_instruction_type.cpp +++ b/oneflow/core/eager/op_call_instruction_type.cpp @@ -23,9 +23,8 @@ limitations under the License. #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" -#include "oneflow/core/vm/cuda_stream_type.h" -#include "oneflow/core/eager/opkernel_instruction_type.h" -#include "oneflow/core/eager/local_call_opkernel_phy_instr_operand.h" +#include "oneflow/core/eager/op_call_instruction_type.h" +#include "oneflow/core/eager/op_call_phy_instr_operand.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/framework/user_op_registry_manager.h" @@ -33,7 +32,7 @@ limitations under the License. #include "oneflow/core/register/ofblob.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/operator/op_conf_symbol.h" -#include "oneflow/user/kernels/stateful_local_opkernel.h" +#include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/profiler/collection.h" #include "oneflow/core/common/cpp_attribute.h" @@ -41,12 +40,12 @@ limitations under the License. namespace oneflow { namespace vm { -struct LocalCallOpKernelUtil final { +struct OpCallInstructionUtil final { static inline Maybe Compute(const vm::InstructionMsg& instr_msg) { OF_PROFILER_RANGE_PUSH("ResetPrior"); - auto* operand = LocalCallOpKernelUtil::GetLocalCallOpKernelPhyInstrOperand(instr_msg); + auto* operand = OpCallInstructionUtil::GetCallPhyInstrOperand(instr_msg); operand->mut_opkernel()->composed_attrs_for_scheduler_thread()->ResetPrior(operand->attrs()); - DeviceCtx* device_ctx = instr_msg.phy_instr_stream()->device_ctx().get(); + DeviceCtx* device_ctx = instr_msg.stream().device_ctx().get(); OF_PROFILER_RANGE_POP(); OF_PROFILER_RANGE_PUSH("AllocateOutputBlobsMemory"); JUST(AllocateOutputBlobsMemory(operand, device_ctx)); @@ -70,14 +69,13 @@ struct LocalCallOpKernelUtil final { return Maybe::Ok(); } - static inline LocalCallOpKernelPhyInstrOperand* GetLocalCallOpKernelPhyInstrOperand( - const vm::InstructionMsg& instr_msg) { + static inline OpCallPhyInstrOperand* GetCallPhyInstrOperand(const vm::InstructionMsg& instr_msg) { auto* operand = CHECK_NOTNULL(instr_msg.phy_instr_operand().get()); - return CHECK_NOTNULL(dynamic_cast(operand)); + return CHECK_NOTNULL(dynamic_cast(operand)); } private: - static inline void InferTempStorageBlobDesc(LocalCallOpKernelPhyInstrOperand* operand) { + static inline void InferTempStorageBlobDesc(OpCallPhyInstrOperand* operand) { const auto& InferTmpSizeFn = operand->opkernel().GetInferTmpSizeFn(operand->user_opkernel()); auto* temp_eager_blob_object = operand->mut_opkernel()->mut_temp_blob_object(); CHECK(temp_eager_blob_object->data_type() == DataType::kChar); @@ -93,7 +91,7 @@ struct LocalCallOpKernelUtil final { op_infer_ctx->Update(nullptr, nullptr, nullptr); } - static inline void TryInitOpKernelStateAndCache(LocalCallOpKernelPhyInstrOperand* operand, + static inline void TryInitOpKernelStateAndCache(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx, user_op::OpKernelState** state, user_op::OpKernelCache** cache) { @@ -108,7 +106,7 @@ struct LocalCallOpKernelUtil final { operand->consistent_tensor_infer_result().get(), state, cache); } - static inline Maybe AllocateOutputBlobsMemory(LocalCallOpKernelPhyInstrOperand* operand, + static inline Maybe AllocateOutputBlobsMemory(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx) { for (const auto& blob_object : *operand->outputs()) { JUST(blob_object->TryAllocateBlobBodyMemory(device_ctx)); @@ -116,13 +114,13 @@ struct LocalCallOpKernelUtil final { return Maybe::Ok(); } - static inline Maybe TryAllocateTempStorageBlobMemory( - LocalCallOpKernelPhyInstrOperand* operand, DeviceCtx* device_ctx) { + static inline Maybe TryAllocateTempStorageBlobMemory(OpCallPhyInstrOperand* operand, + DeviceCtx* device_ctx) { return operand->mut_opkernel()->mut_temp_blob_object()->TryAllocateBlobBodyMemory(device_ctx); } - static inline void OpKernelCompute(LocalCallOpKernelPhyInstrOperand* operand, - DeviceCtx* device_ctx, user_op::OpKernelState* state, + static inline void OpKernelCompute(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx, + user_op::OpKernelState* state, const user_op::OpKernelCache* cache) { auto* opkernel = operand->mut_opkernel(); auto* compute_ctx = @@ -161,30 +159,28 @@ struct LocalCallOpKernelUtil final { operand->user_opkernel()->Compute(compute_ctx, state, cache); } OF_PROFILER_RANGE_POP(); - // tensor tuples are not allowed to be hold by StatefulLocalOpKernel + // tensor tuples are not allowed to be hold by StatefulOpKernel opkernel->UpdateComputeContext(nullptr, nullptr, nullptr, nullptr); } - static inline Maybe DeallocateTempStorageBlobMemory( - LocalCallOpKernelPhyInstrOperand* operand, DeviceCtx* device_ctx) { + static inline Maybe DeallocateTempStorageBlobMemory(OpCallPhyInstrOperand* operand, + DeviceCtx* device_ctx) { return operand->mut_opkernel()->mut_temp_blob_object()->DeallocateBlobDataPtr(); } }; -void LocalCallOpKernelInstructionType::Compute(vm::Instruction* instruction) const { - CHECK_JUST(LocalCallOpKernelUtil::Compute(instruction->instr_msg())); +void OpCallInstructionType::Compute(vm::Instruction* instruction) const { + CHECK_JUST(OpCallInstructionUtil::Compute(instruction->instr_msg())); } -void LocalCallOpKernelInstructionType::ComputeInFuseMode(vm::InstructionMsg* instr_msg) const { - CHECK_JUST(LocalCallOpKernelUtil::Compute(*instr_msg)); +void OpCallInstructionType::ComputeInFuseMode(vm::InstructionMsg* instr_msg) const { + CHECK_JUST(OpCallInstructionUtil::Compute(*instr_msg)); } -std::string LocalCallOpKernelInstructionType::DebugOpTypeName( - const vm::InstructionMsg& instr_msg) const { +std::string OpCallInstructionType::DebugName(const vm::InstructionMsg& instr_msg) const { auto* operand = CHECK_NOTNULL(instr_msg.phy_instr_operand().get()); - return CHECK_NOTNULL(dynamic_cast(operand)) - ->opkernel() - .op_type_name(); + return CHECK_NOTNULL(dynamic_cast(operand))->opkernel().op_type_name() + + ":Call"; } } // namespace vm diff --git a/oneflow/core/eager/opkernel_instruction_type.h b/oneflow/core/eager/op_call_instruction_type.h similarity index 70% rename from oneflow/core/eager/opkernel_instruction_type.h rename to oneflow/core/eager/op_call_instruction_type.h index bc860a6df05..31aacb6fd7b 100644 --- a/oneflow/core/eager/opkernel_instruction_type.h +++ b/oneflow/core/eager/op_call_instruction_type.h @@ -13,10 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_EAGER_CALL_OPKERNEL_INSTRUCTION_H_ -#define ONEFLOW_CORE_EAGER_CALL_OPKERNEL_INSTRUCTION_H_ +#ifndef ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_ +#define ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_ -#include "oneflow/core/vm/instr_type_id.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/memory/memory_case.pb.h" @@ -24,19 +23,19 @@ limitations under the License. namespace oneflow { namespace vm { -class LocalCallOpKernelInstructionType : public vm::InstructionType { +class OpCallInstructionType final : public vm::InstructionType { public: + OpCallInstructionType() = default; + ~OpCallInstructionType() = default; + void Compute(vm::Instruction* instruction) const override; void ComputeInFuseMode(vm::InstructionMsg* instr_msg) const override; InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; } - std::string DebugOpTypeName(const vm::InstructionMsg& instr_msg) const override; + std::string DebugName(const vm::InstructionMsg& instr_msg) const override; protected: - LocalCallOpKernelInstructionType() = default; - virtual ~LocalCallOpKernelInstructionType() = default; - private: Maybe MaybeCompute(vm::Instruction* instruction) const; }; @@ -44,4 +43,4 @@ class LocalCallOpKernelInstructionType : public vm::InstructionType { } // namespace vm } // namespace oneflow -#endif // ONEFLOW_CORE_EAGER_CALL_OPKERNEL_INSTRUCTION_H_ +#endif // ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_ diff --git a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.cpp b/oneflow/core/eager/op_call_phy_instr_operand.cpp similarity index 78% rename from oneflow/core/eager/local_call_opkernel_phy_instr_operand.cpp rename to oneflow/core/eager/op_call_phy_instr_operand.cpp index 07250c580ae..cd553b59a54 100644 --- a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.cpp +++ b/oneflow/core/eager/op_call_phy_instr_operand.cpp @@ -13,21 +13,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/eager/local_call_opkernel_phy_instr_operand.h" -#include "oneflow/user/kernels/stateful_local_opkernel.h" +#include "oneflow/core/eager/op_call_phy_instr_operand.h" +#include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" +#include "oneflow/core/vm/stream.h" namespace oneflow { namespace vm { -Maybe LocalCallOpKernelPhyInstrOperand::Init() { +Maybe OpCallPhyInstrOperand::Init() { JUST(mut_opkernel()->ChooseOpKernel(&user_opkernel_, &need_temp_storage_, attrs(), inputs().get(), outputs().get(), consistent_tensor_infer_result().get())); return Maybe::Ok(); } -void LocalCallOpKernelPhyInstrOperand::ForEachConstMirroredObject( +void OpCallPhyInstrOperand::ForEachConstMirroredObject( const std::function& DoEach) const { const auto& input_list = inputs(); for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) { @@ -36,10 +37,9 @@ void LocalCallOpKernelPhyInstrOperand::ForEachConstMirroredObject( } } -void LocalCallOpKernelPhyInstrOperand::InitStreamSequentialDependence() { - const auto& stream = opkernel().stream(); - auto* device_schedule_dep_object = stream->mut_schedule_local_dep_object(); - if (StreamRoleSwitch(stream->stream_role())) { +void OpCallPhyInstrOperand::InitStreamSequentialDependence() { + auto* device_schedule_dep_object = vm_stream_->schedule_local_dep_object().get(); + if (IsCommNetStream::Visit(vm_stream_->stream_role())) { // Sequantialize nccl instructions to avoid deadlock stream_sequential_dependence_ = device_schedule_dep_object; } else { @@ -53,11 +53,10 @@ void LocalCallOpKernelPhyInstrOperand::InitStreamSequentialDependence() { } } -void LocalCallOpKernelPhyInstrOperand::ForEachMutMirroredObject( +void OpCallPhyInstrOperand::ForEachMutMirroredObject( const std::function& DoEach) const { - const auto& stream = opkernel().stream(); - const auto& opt_transport_dep_object = stream->mut_transport_local_dep_object(); - if (opt_transport_dep_object.has_value()) { DoEach(CHECK_JUST(opt_transport_dep_object)); } + const auto& opt_transport_dep_object = vm_stream_->transport_local_dep_object(); + if (opt_transport_dep_object.has_value()) { DoEach(CHECK_JUST(opt_transport_dep_object)->get()); } const auto& input_list = inputs(); for (int64_t index : opkernel().input_tuple_indexes4mut_ibns()) { @@ -71,7 +70,7 @@ void LocalCallOpKernelPhyInstrOperand::ForEachMutMirroredObject( } } -void LocalCallOpKernelPhyInstrOperand::ForEachMut2MirroredObject( +void OpCallPhyInstrOperand::ForEachMut2MirroredObject( const std::function& DoEach) const { const auto& output_list = outputs(); for (int64_t index : opkernel().output_tuple_indexes4mut2_obns()) { diff --git a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h b/oneflow/core/eager/op_call_phy_instr_operand.h similarity index 78% rename from oneflow/core/eager/local_call_opkernel_phy_instr_operand.h rename to oneflow/core/eager/op_call_phy_instr_operand.h index 90cec6beb18..3a67d1f5995 100644 --- a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h +++ b/oneflow/core/eager/op_call_phy_instr_operand.h @@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_EAGER_LOCAL_CALL_OPKERNEL_PHY_INSTR_OPERAND_H_ -#define ONEFLOW_CORE_EAGER_LOCAL_CALL_OPKERNEL_PHY_INSTR_OPERAND_H_ +#ifndef ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_ +#define ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_ #include "oneflow/core/vm/phy_instr_operand.h" #include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h" @@ -23,9 +23,14 @@ limitations under the License. #include "oneflow/core/framework/op_interpreter.h" namespace oneflow { + +namespace vm { +class Stream; +} + namespace one { -class StatefulLocalOpKernel; +class StatefulOpKernel; class ConsistentTensorInferResult; using EagerBlobObjectList = std::vector>; @@ -42,20 +47,20 @@ class OpKernel; namespace vm { -class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { +class OpCallPhyInstrOperand final : public vm::PhyInstrOperand { public: - LocalCallOpKernelPhyInstrOperand(const LocalCallOpKernelPhyInstrOperand&) = delete; - LocalCallOpKernelPhyInstrOperand(LocalCallOpKernelPhyInstrOperand&&) = delete; - ~LocalCallOpKernelPhyInstrOperand() override = default; + OpCallPhyInstrOperand(const OpCallPhyInstrOperand&) = delete; + OpCallPhyInstrOperand(OpCallPhyInstrOperand&&) = delete; + ~OpCallPhyInstrOperand() override = default; template - static Maybe New(Args&&... args) { - auto* ptr = new LocalCallOpKernelPhyInstrOperand(std::forward(args)...); + static Maybe New(Args&&... args) { + auto* ptr = new OpCallPhyInstrOperand(std::forward(args)...); JUST(ptr->Init()); - return std::shared_ptr(ptr); + return std::shared_ptr(ptr); } - const one::StatefulLocalOpKernel& opkernel() const { return *opkernel_; } + const one::StatefulOpKernel& opkernel() const { return *opkernel_; } const one::EagerBlobObjectListPtr& inputs() const { return inputs_; } const one::EagerBlobObjectListPtr& outputs() const { return outputs_; } const AttrMap& attrs() const { return op_interp_ctx_.attrs; } @@ -64,7 +69,7 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { return dev_vm_dep_object_consume_mode_; } - one::StatefulLocalOpKernel* mut_opkernel() { return opkernel_.get(); } + one::StatefulOpKernel* mut_opkernel() { return opkernel_.get(); } template Maybe ForEachOutputTensor(const DoEachT& DoEach) { @@ -90,13 +95,14 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { } private: - LocalCallOpKernelPhyInstrOperand( - const std::shared_ptr& opkernel, + OpCallPhyInstrOperand( + vm::Stream* vm_stream, const std::shared_ptr& opkernel, const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& op_interp_ctx_, const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode) - : opkernel_(opkernel), + : vm_stream_(vm_stream), + opkernel_(opkernel), inputs_(inputs), outputs_(outputs), consistent_tensor_infer_result_(consistent_tensor_infer_result), @@ -113,7 +119,8 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { Maybe Init(); void InitStreamSequentialDependence(); - std::shared_ptr opkernel_; + vm::Stream* vm_stream_; + std::shared_ptr opkernel_; one::EagerBlobObjectListPtr inputs_; one::EagerBlobObjectListPtr outputs_; std::shared_ptr consistent_tensor_infer_result_; @@ -128,4 +135,4 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { } // namespace vm } // namespace oneflow -#endif // ONEFLOW_CORE_EAGER_LOCAL_CALL_OPKERNEL_PHY_INSTR_OPERAND_H_ +#endif // ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_ diff --git a/oneflow/core/eager/release_tensor_arg_phy_instr_operand.h b/oneflow/core/eager/release_tensor_arg_phy_instr_operand.h index 742847f4c1c..f958a087cde 100644 --- a/oneflow/core/eager/release_tensor_arg_phy_instr_operand.h +++ b/oneflow/core/eager/release_tensor_arg_phy_instr_operand.h @@ -26,6 +26,7 @@ limitations under the License. #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" +#include "oneflow/core/vm/stream.h" namespace oneflow { @@ -36,11 +37,11 @@ class EagerBlobObject; class ReleaseTensorArgPhyInstrOperand : public PhyInstrOperand { public: ReleaseTensorArgPhyInstrOperand(const std::shared_ptr& eager_blob_object, - const Optional>& stream) + const Optional& stream) : eager_blob_object_(eager_blob_object), output_dependences_() { output_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object())); if (stream.has_value()) { - stream_sequential_dependence_ = CHECK_JUST(stream)->mut_schedule_local_dep_object(); + stream_sequential_dependence_ = CHECK_JUST(stream)->schedule_local_dep_object().get(); } } ~ReleaseTensorArgPhyInstrOperand() override = default; diff --git a/oneflow/core/eager/release_tensor_instruction_type.cpp b/oneflow/core/eager/release_tensor_instruction_type.h similarity index 53% rename from oneflow/core/eager/release_tensor_instruction_type.cpp rename to oneflow/core/eager/release_tensor_instruction_type.h index 682b04587b6..427581a1d08 100644 --- a/oneflow/core/eager/release_tensor_instruction_type.cpp +++ b/oneflow/core/eager/release_tensor_instruction_type.h @@ -13,28 +13,26 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_ +#define ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_ + #include "oneflow/core/vm/instruction.h" +#include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/eager/release_tensor_arg_phy_instr_operand.h" #include "oneflow/core/eager/eager_blob_object.h" -#include "oneflow/core/vm/cuda_stream_type.h" -#include "oneflow/core/vm/async_cuda_stream_type.h" -#include "oneflow/core/vm/cuda_copy_h2d_stream_type.h" -#include "oneflow/core/vm/cuda_copy_d2h_stream_type.h" -#include "oneflow/core/vm/cpu_stream_type.h" #include "oneflow/core/vm/cuda_optional_event_record_status_querier.h" +#include "oneflow/core/common/stream_role.h" +#include "oneflow/core/common/singleton_ptr.h" namespace oneflow { namespace vm { -template class ReleaseTensorInstructionType : public vm::InstructionType { public: ReleaseTensorInstructionType() = default; ~ReleaseTensorInstructionType() override = default; - using stream_type = StreamT; - InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; } void Release(const vm::InstructionMsg& instr_msg) const { @@ -45,19 +43,16 @@ class ReleaseTensorInstructionType : public vm::InstructionType { CHECK_NOTNULL(ptr); CHECK_JUST(ptr->eager_blob_object()->DeallocateBlobDataPtr()); } + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { + return "ReleaseTensor"; + } void Compute(vm::Instruction* instruction) const override { Release(instruction->instr_msg()); } void ComputeInFuseMode(vm::InstructionMsg* instr_msg) const override { Release(*instr_msg); } }; -COMMAND( - vm::RegisterInstructionType>("cpu.ReleaseTensor")); -COMMAND(vm::RegisterInstructionType>( - "comm_net.ReleaseTensor")); - #ifdef WITH_CUDA -template -class CudaReleaseTensorInstructionType : public ReleaseTensorInstructionType { +class CudaReleaseTensorInstructionType : public ReleaseTensorInstructionType { public: CudaReleaseTensorInstructionType() = default; ~CudaReleaseTensorInstructionType() override = default; @@ -71,17 +66,51 @@ class CudaReleaseTensorInstructionType : public ReleaseTensorInstructionType>( - "cuda.ReleaseTensor")); -COMMAND(vm::RegisterInstructionType>( - "cuda_h2d.ReleaseTensor")); -COMMAND(vm::RegisterInstructionType>( - "cuda_d2h.ReleaseTensor")); -COMMAND(vm::RegisterInstructionType>( - "sync_launched_nccl.ReleaseTensor")); -COMMAND(vm::RegisterInstructionType>( - "async_launched_nccl.ReleaseTensor")); #endif } // namespace vm + +struct GetReleaseInstructionType : public StreamRoleVisitor { + static Maybe VisitCompute(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitHost2Device(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitDevice2Host(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitSyncedLaunchedCommNet(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitAsyncedLaunchedCommNet(DeviceType device_type) { + return GetInstructionType(device_type); + } + static Maybe VisitBarrier(DeviceType device_type) { + UNIMPLEMENTED_THEN_RETURN(); + } + static Maybe VisitCriticalSection(DeviceType device_type) { + UNIMPLEMENTED_THEN_RETURN(); + } + static Maybe VisitLazyJobLauncher(DeviceType device_type) { + UNIMPLEMENTED_THEN_RETURN(); + } + + private: + static Maybe GetInstructionType(DeviceType device_type) { + if (device_type == DeviceType::kCPU) { + return SingletonPtr(); + } else if (device_type == DeviceType::kCUDA) { +#ifdef WITH_CUDA + return SingletonPtr(); +#else + UNIMPLEMENTED_THEN_RETURN(); +#endif + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } +}; + } // namespace oneflow +#endif // ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_ diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 6d2121bb5b2..f3b15dcd15c 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -26,21 +26,25 @@ limitations under the License. #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/blocking_counter.h" +#include "oneflow/core/common/singleton_ptr.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/barrier_phy_instr_operand.h" #include "oneflow/core/vm/access_blob_arg_cb_phy_instr_operand.h" #include "oneflow/core/vm/consume_local_dep_object_phy_instr_operand.h" -#include "oneflow/core/eager/release_tensor_arg_phy_instr_operand.h" +#include "oneflow/core/eager/release_tensor_instruction_type.h" +#include "oneflow/core/eager/blob_instruction_type.h" +#include "oneflow/core/eager/op_call_instruction_type.h" +#include "oneflow/core/vm/barrier_instruction_type.h" #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/framework/consistent_tensor_infer_cache.h" #include "oneflow/core/eager/local_dep_object.h" +#include "oneflow/core/eager/critical_section_instruction_type.h" +#include "oneflow/core/eager/lazy_job_instruction_type.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stream.h" #include "oneflow/core/framework/stream_need_soft_sync.h" -#include "oneflow/core/framework/stream_get_call_instruction_name.h" -#include "oneflow/core/framework/stream_get_release_instruction_name.h" #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/job/env_desc.h" #include "oneflow/core/profiler/profiler.h" @@ -57,24 +61,29 @@ Maybe> RawGetCriticalSectionStream() { static constexpr auto* GetCriticalSectionStream = DECORATE(&RawGetCriticalSectionStream, ThreadLocal); +Maybe> RawGetLazyJobLauncherStream() { + return Stream::New(JUST(Device::New("cpu")), StreamRole::kLazyJobLauncher); +} + +static constexpr auto* GetLazyJobLauncherStream = + DECORATE(&RawGetLazyJobLauncherStream, ThreadLocal); + } // namespace template Maybe InstructionsBuilder::MakeCriticalSectionBegin( - const std::shared_ptr& phy_instr_operand) { + vm::Stream* vm_stream, const std::shared_ptr& phy_instr_operand) { auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), "CriticalSectionBegin", - std::shared_ptr(), phy_instr_operand); + vm_stream, SingletonPtr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } template Maybe InstructionsBuilder::MakeCriticalSectionEnd( - const std::shared_ptr& phy_instr_operand) { + vm::Stream* vm_stream, const std::shared_ptr& phy_instr_operand) { auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), "CriticalSectionEnd", - std::shared_ptr(), phy_instr_operand); + vm_stream, SingletonPtr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } @@ -138,10 +147,13 @@ Maybe InstructionsBuilder::LaunchLazyJob(const one::EagerBlobObjectListPtr const auto& event_record = std::make_shared(); CHECK_OR_RETURN(input_op_name2end_event_record->emplace(op_name, event_record).second); } + + auto stream = JUST(GetCriticalSectionStream()); + auto* vm_stream = JUST(Global::Get()->GetVmStream(stream)); const auto& phy_instr_operand = std::make_shared( - nn_graph, inputs, input_op_name2end_event_record); - JUST(MakeCriticalSectionBegin(phy_instr_operand)); + nn_graph, inputs, input_op_name2end_event_record, vm_stream); + JUST(MakeCriticalSectionBegin(vm_stream, phy_instr_operand)); } const auto& output_op_name2end_event_record = std::make_shared>>(); @@ -150,34 +162,39 @@ Maybe InstructionsBuilder::LaunchLazyJob(const one::EagerBlobObjectListPtr const auto& event_record = std::make_shared(); CHECK_OR_RETURN(output_op_name2end_event_record->emplace(op_name, event_record).second); } + auto stream = JUST(GetCriticalSectionStream()); + auto* vm_stream = JUST(Global::Get()->GetVmStream(stream)); const auto& phy_instr_operand = std::make_shared( - nn_graph, outputs, output_op_name2end_event_record); - JUST(MakeCriticalSectionBegin(phy_instr_operand)); + nn_graph, outputs, output_op_name2end_event_record, vm_stream); + JUST(MakeCriticalSectionBegin(vm_stream, phy_instr_operand)); } { const auto& phy_instr_operand = std::make_shared(nn_graph, parameters); + auto stream = JUST(GetLazyJobLauncherStream()); + auto* vm_stream = JUST(Global::Get()->GetVmStream(stream)); auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), "LaunchLazyJob", - std::shared_ptr(), phy_instr_operand); + vm_stream, SingletonPtr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); } + auto stream = JUST(GetCriticalSectionStream()); + auto* vm_stream = JUST(Global::Get()->GetVmStream(stream)); for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) { const auto& eager_blob_object = inputs->at(i); const auto& op_name = nn_graph->inputs_op_names().at(i); const auto& event_record = JUST(MapAt(*input_op_name2end_event_record, op_name)); const auto& phy_instr_operand = std::make_shared( - eager_blob_object, event_record); - JUST(MakeCriticalSectionEnd(phy_instr_operand)); + eager_blob_object, event_record, vm_stream); + JUST(MakeCriticalSectionEnd(vm_stream, phy_instr_operand)); } for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) { const auto& eager_blob_object = outputs->at(i); const auto& op_name = nn_graph->outputs_op_names().at(i); const auto& event_record = JUST(MapAt(*output_op_name2end_event_record, op_name)); const auto& phy_instr_operand = std::make_shared( - eager_blob_object, event_record); - JUST(MakeCriticalSectionEnd(phy_instr_operand)); + eager_blob_object, event_record, vm_stream); + JUST(MakeCriticalSectionEnd(vm_stream, phy_instr_operand)); } } return Maybe::Ok(); @@ -191,26 +208,29 @@ Maybe InstructionsBuilder::SoftSyncNNGraphBuffers( return Maybe::Ok(); } -Maybe InstructionsBuilder::CreateSymbolId() { return JUST(id_generator_->NewSymbolId()); } +namespace { + +int64_t NewSymbolId() { + static std::atomic cnt(0); + return cnt.fetch_add(1, std::memory_order_relaxed); +} + +} // namespace Maybe InstructionsBuilder::GetJobConfSymbol(const JobConfigProto& job_conf) { - return Global>::Get()->FindOrCreate( - job_conf, [&] { return this->CreateSymbolId(); }); + return Global>::Get()->FindOrCreate(job_conf, &NewSymbolId); } Maybe InstructionsBuilder::GetParallelDescSymbol(const ParallelConf& parallel_conf) { - return Global>::Get()->FindOrCreate( - parallel_conf, [&] { return this->CreateSymbolId(); }); + return Global>::Get()->FindOrCreate(parallel_conf, &NewSymbolId); } Maybe InstructionsBuilder::GetScopeSymbol(const ScopeProto& scope_proto) { - return Global>::Get()->FindOrCreate( - scope_proto, [&] { return this->CreateSymbolId(); }); + return Global>::Get()->FindOrCreate(scope_proto, &NewSymbolId); } Maybe InstructionsBuilder::GetOpConfSymbol(const OperatorConf& op_conf) { - return Global>::Get()->FindOrCreate( - op_conf, [&] { return this->CreateSymbolId(); }); + return Global>::Get()->FindOrCreate(op_conf, &NewSymbolId); } Maybe InstructionsBuilder::BuildInitialScope( @@ -337,32 +357,27 @@ Maybe InstructionsBuilder::BuildScopeByProtoStrSetter( return GetScopeSymbol(*scope_proto); } -Maybe InstructionsBuilder::LocalCallOpKernel( - const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& input_eager_blob_objects, - const one::EagerBlobObjectListPtr& output_eager_blob_objects, - const one::OpExprInterpContext& ctx, Symbol stream) { - return LocalCallOpKernel(opkernel, input_eager_blob_objects, output_eager_blob_objects, nullptr, - ctx, stream); +Maybe InstructionsBuilder::Call(const std::shared_ptr& opkernel, + const one::EagerBlobObjectListPtr& input_eager_blob_objects, + const one::EagerBlobObjectListPtr& output_eager_blob_objects, + const one::OpExprInterpContext& ctx, Symbol stream) { + return Call(opkernel, input_eager_blob_objects, output_eager_blob_objects, nullptr, ctx, stream); } -Maybe InstructionsBuilder::LocalCallOpKernel( - const std::shared_ptr& opkernel, +Maybe InstructionsBuilder::Call( + const std::shared_ptr& opkernel, const one::EagerBlobObjectListPtr& input_eager_blob_objects, const one::EagerBlobObjectListPtr& output_eager_blob_objects, const std::shared_ptr& consistent_tensor_infer_result, const one::OpExprInterpContext& ctx, Symbol stream) { - const auto& parallel_desc_sym = JUST(Placement4Device(stream->device())).shared_from_symbol(); JUST(SoftSyncStream(output_eager_blob_objects, stream)); JUST(SoftSyncStream(input_eager_blob_objects, stream)); - auto phy_instr_operand = JUST(vm::LocalCallOpKernelPhyInstrOperand::New( - opkernel, input_eager_blob_objects, output_eager_blob_objects, consistent_tensor_infer_result, - ctx, *one::CurrentDevVmDepObjectConsumeMode())); - const auto& instruction_name = JUST(StreamRoleSwitch( - stream->stream_role(), stream->device()->enum_type())); + auto* vm_stream = JUST(Global::Get()->GetVmStream(stream)); + auto phy_instr_operand = JUST(vm::OpCallPhyInstrOperand::New( + vm_stream, opkernel, input_eager_blob_objects, output_eager_blob_objects, + consistent_tensor_infer_result, ctx, *one::CurrentDevVmDepObjectConsumeMode())); auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), instruction_name, parallel_desc_sym, - phy_instr_operand); + vm_stream, SingletonPtr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); for (const auto& output : *output_eager_blob_objects) { if (!output->producer_stream().has_value()) { JUST(output->init_producer_stream(stream)); } @@ -372,14 +387,13 @@ Maybe InstructionsBuilder::LocalCallOpKernel( } Maybe InstructionsBuilder::ReleaseTensor( - const std::shared_ptr& eager_blob_object, - const std::shared_ptr& parallel_desc) { - if (pthread_fork::IsForkedSubProcess() && parallel_desc - && parallel_desc->device_type() != DeviceType::kCPU) { - return Maybe::Ok(); - } + const std::shared_ptr& eager_blob_object) { const auto& last_used_stream = JUST(eager_blob_object->last_used_stream()); const auto& producer_stream = JUST(eager_blob_object->producer_stream()); + if (pthread_fork::IsForkedSubProcess() + && producer_stream->device()->enum_type() != DeviceType::kCPU) { + return Maybe::Ok(); + } if (last_used_stream != producer_stream) { JUST(SoftSyncStream({JUST(eager_blob_object->compute_local_dep_object())}, "mut", last_used_stream)); @@ -387,23 +401,26 @@ Maybe InstructionsBuilder::ReleaseTensor( Optional> stream{}; if (*one::CurrentDevVmDepObjectConsumeMode() == one::DevVmDepObjectConsumeMode::NONE) { stream = Optional>(NullOpt); - } else if (StreamRoleSwitch(last_used_stream->stream_role())) { + } else if (IsCommNetStream::Visit(last_used_stream->stream_role())) { // Disable inter-device instruction sequential for tensor used by communicative stream. // It's not acceptable for us that cuda compute stream is blocked by cuda nccl stream. stream = Optional>(NullOpt); - } else if (StreamRoleSwitch(producer_stream->stream_role())) { + } else if (IsCommNetStream::Visit(producer_stream->stream_role())) { // Disable inter-device instruction sequential for tensor produced by communicative stream. stream = Optional>(NullOpt); } else { stream = producer_stream; } + auto vm_stream = stream.map([](Symbol stream) -> vm::Stream* { + return CHECK_JUST(Global::Get()->GetVmStream(stream)); + }); const auto& phy_instr_operand = - std::make_shared(eager_blob_object, stream); + std::make_shared(eager_blob_object, vm_stream); + StreamRole stream_role = producer_stream->stream_role(); DeviceType device_type = producer_stream->device()->enum_type(); - const auto& instruction_name = JUST( - StreamRoleSwitch(producer_stream->stream_role(), device_type)); auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), instruction_name, parallel_desc, phy_instr_operand); + JUST(Global::Get()->GetVmStream(producer_stream)), + JUST(GetReleaseInstructionType::Visit(stream_role, device_type)), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } @@ -435,39 +452,22 @@ Maybe InstructionsBuilder::SoftSyncStream( Maybe InstructionsBuilder::SoftSyncStream( std::vector>&& compute_local_dep_objects, - const std::string& modifier, Symbol stream) { - DeviceType device_type = stream->device()->enum_type(); - if (!StreamRoleSwitch(stream->stream_role(), device_type)) { + const std::string& modifier, Symbol last_used_stream) { + DeviceType device_type = last_used_stream->device()->enum_type(); + if (!NeedSoftSync::Visit(last_used_stream->stream_role(), device_type)) { return Maybe::Ok(); } OF_PROFILER_RANGE_GUARD("SoftStream"); - const auto& parallel_desc = JUST(Placement4Device(stream->device())).shared_from_symbol(); const auto& phy_instr_operand = std::make_shared( std::move(compute_local_dep_objects), modifier); + StreamRole stream_role = last_used_stream->stream_role(); auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), parallel_desc->device_tag() + ".RecordEvent", - parallel_desc, phy_instr_operand); + JUST(Global::Get()->GetVmStream(last_used_stream)), + JUST(GetRecordEventInstructionType::Visit(stream_role, device_type)), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } -namespace { - -const std::shared_ptr& GetParallelDesc( - const std::shared_ptr tensor) { - const auto& device = CHECK_JUST(tensor->device()); - const auto& placement = CHECK_JUST(Placement4Device(device)); - return placement.shared_from_symbol(); -} - -const std::shared_ptr& GetParallelDesc( - const one::EagerMirroredTensorImpl* tensor) { - const auto& placement = CHECK_JUST(Placement4Device(tensor->device())); - return placement.shared_from_symbol(); -} - -} // namespace - template Maybe InstructionsBuilder::SyncAccessBlobByCallback( const T tensor, const std::shared_ptr& btb, @@ -520,17 +520,41 @@ template Maybe InstructionsBuilder::SyncAccessBlobByCallback( const one::EagerMirroredTensorImpl* tensor, const std::shared_ptr& btb, const std::function& Callback, const std::string& modifier); +namespace { + +Maybe> GetDevice(const std::shared_ptr& tensor) { + return tensor->device(); // return Maybe> +} + +Maybe> GetDevice(const one::EagerMirroredTensorImpl* tensor) { + return tensor->device(); // return const Symbol& +} + +} // namespace + template Maybe InstructionsBuilder::AccessBlobByCallback(const T tensor, const std::function& callback, const std::string& modifier) { - const auto& parallel_desc = GetParallelDesc(tensor); const std::shared_ptr& eager_blob_object = JUST(tensor->eager_blob_object()); const auto& phy_instr_operand = std::make_shared(eager_blob_object, callback, modifier); + Symbol device = JUST(GetDevice(tensor)); + Symbol stream = JUST(GetDefaultStreamByDevice(device)); + // Do not use producer_stream or last_used_stream. + // Bug case when using producer_stream or last_used_stream: + // + // ```python + // tensor = oneflow.ones((1024, 1024, 1024), device='cuda').cpu() + // ndarray = tensor.numpy() # share memory + // + // ``` + // `ndarray` may not be ones because instruction AccessBlobByCallback is prescheduled before + // oneflow.ones actually finished. auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), - parallel_desc->device_tag() + ".AccessBlobByCallback", parallel_desc, phy_instr_operand); + // Never replace `stream` with producer_stream or last_used_stream. + JUST(Global::Get()->GetVmStream(stream)), + SingletonPtr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } @@ -543,29 +567,38 @@ template Maybe InstructionsBuilder::AccessBlobByCallback( const one::EagerMirroredTensorImpl* tensor, const std::function& callback, const std::string& modifier); -Maybe InstructionsBuilder::ComputeRankFrontSeqCallback( - const std::function& callback) { - const auto& phy_instr_operand = std::make_shared(callback); +namespace { + +Maybe> GetBarrierStream() { + auto device = JUST(Device::New("cpu")); + return Stream::New(device, StreamRole::kBarrier); +} + +} // namespace + +Maybe InstructionsBuilder::GlobalSync() { + const auto& phy_instr_operand = std::make_shared([]() {}); + auto stream = JUST(GetBarrierStream()); auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), "ComputeRankFrontSeqCallback", - std::shared_ptr(), phy_instr_operand); + JUST(Global::Get()->GetVmStream(stream)), + SingletonPtr(), phy_instr_operand); instruction_list_->PushBack(instruction.Mutable()); return Maybe::Ok(); } -Maybe InstructionsBuilder::ComputeGlobalFrontSeqBarrier() { - const auto& phy_instr_operand = std::make_shared([] {}); +Maybe InstructionsBuilder::Barrier(const std::function& Callback) { + const auto& phy_instr_operand = std::make_shared(Callback); + auto stream = JUST(GetBarrierStream()); auto instruction = intrusive::make_shared( - Global::Get()->mut_vm(), "ComputeGlobalFrontSeqBarrier", - std::shared_ptr(), phy_instr_operand); + JUST(Global::Get()->GetVmStream(stream)), + SingletonPtr(), phy_instr_operand); instruction_list_->PushBack(instruction.Mutable()); return Maybe::Ok(); } Maybe PhysicalRun(const std::function(InstructionsBuilder*)>& Build) { vm::InstructionMsgList instruction_list; - InstructionsBuilder instructions_builder(std::make_shared(), - &instruction_list); + InstructionsBuilder instructions_builder(&instruction_list); JUST(Build(&instructions_builder)); JUST(vm::Run(instructions_builder.mut_instruction_list())); return Maybe::Ok(); diff --git a/oneflow/core/framework/instructions_builder.h b/oneflow/core/framework/instructions_builder.h index 8bf70c203b8..ddbb017d986 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -16,10 +16,9 @@ limitations under the License. #ifndef ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_ #define ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_ -#include "oneflow/core/eager/local_call_opkernel_phy_instr_operand.h" +#include "oneflow/core/eager/op_call_phy_instr_operand.h" #include "oneflow/core/eager/lazy_job_phy_instr_operand.h" #include "oneflow/core/vm/instruction.h" -#include "oneflow/core/vm/id_generator.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/scope.h" @@ -33,7 +32,7 @@ limitations under the License. namespace oneflow { namespace one { -class StatefulLocalOpKernel; +class StatefulOpKernel; class TensorTuple; class MirroredTensor; class ConsistentTensorInferResult; @@ -47,12 +46,10 @@ class InstructionsBuilder : public std::enable_shared_from_this& id_generator, - vm::InstructionMsgList* instruction_list) - : id_generator_(id_generator), instruction_list_(instruction_list) {} + explicit InstructionsBuilder(vm::InstructionMsgList* instruction_list) + : instruction_list_(instruction_list) {} ~InstructionsBuilder() { instruction_list_->Clear(); } - const std::shared_ptr& id_generator() const { return id_generator_; } const vm::InstructionMsgList& instruction_list() const { return *instruction_list_; } vm::InstructionMsgList* mut_instruction_list() { return instruction_list_; } @@ -67,8 +64,6 @@ class InstructionsBuilder : public std::enable_shared_from_this SoftSyncNNGraphBuffers(const one::EagerBlobObjectListPtr& eager_blob_objects, const std::shared_ptr& nn_graph); - Maybe CreateSymbolId(); - Maybe GetJobConfSymbol(const JobConfigProto& job_conf); Maybe GetParallelDescSymbol(const ParallelConf& parallel_conf); @@ -77,8 +72,7 @@ class InstructionsBuilder : public std::enable_shared_from_this GetOpConfSymbol(const OperatorConf& op_conf); - Maybe ReleaseTensor(const std::shared_ptr& eager_blob_object, - const std::shared_ptr& parallel_desc); + Maybe ReleaseTensor(const std::shared_ptr& eager_blob_object); template Maybe SyncAccessBlobByCallback(const T tensor, const std::shared_ptr& btb, @@ -89,9 +83,8 @@ class InstructionsBuilder : public std::enable_shared_from_this AccessBlobByCallback(const T tensor, const std::function& callback, const std::string& modifier); - Maybe ComputeRankFrontSeqCallback(const std::function& callback); - - Maybe ComputeGlobalFrontSeqBarrier(); + Maybe GlobalSync(); + Maybe Barrier(const std::function& callback); Maybe BuildInitialScope(int64_t session_id, const JobConfigProto& job_conf, const std::string& device_tag, @@ -122,13 +115,13 @@ class InstructionsBuilder : public std::enable_shared_from_this& scope, const std::function& StrSetter); - Maybe LocalCallOpKernel(const std::shared_ptr& opkernel, - const one::EagerBlobObjectListPtr& input_eager_blob_objects, - const one::EagerBlobObjectListPtr& output_eager_blob_objects, - const one::OpExprInterpContext& ctx, Symbol stream); + Maybe Call(const std::shared_ptr& opkernel, + const one::EagerBlobObjectListPtr& input_eager_blob_objects, + const one::EagerBlobObjectListPtr& output_eager_blob_objects, + const one::OpExprInterpContext& ctx, Symbol stream); - Maybe LocalCallOpKernel( - const std::shared_ptr& opkernel, + Maybe Call( + const std::shared_ptr& opkernel, const one::EagerBlobObjectListPtr& input_eager_blob_objects, const one::EagerBlobObjectListPtr& output_eager_blob_objects, const std::shared_ptr& consistent_tensor_infer_result, @@ -141,16 +134,15 @@ class InstructionsBuilder : public std::enable_shared_from_this>&& compute_local_dep_objects, const std::string& modifier, Symbol stream); - vm::IdGenerator* mut_id_generator() { return id_generator_.get(); } - private: template - Maybe MakeCriticalSectionBegin(const std::shared_ptr& phy_instr_operand); + Maybe MakeCriticalSectionBegin(vm::Stream* vm_stream, + const std::shared_ptr& phy_instr_operand); template - Maybe MakeCriticalSectionEnd(const std::shared_ptr& phy_instr_operand); + Maybe MakeCriticalSectionEnd(vm::Stream* vm_stream, + const std::shared_ptr& phy_instr_operand); - std::shared_ptr id_generator_; vm::InstructionMsgList* instruction_list_; }; diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 916c049728e..27e4f65b55a 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -24,7 +24,7 @@ limitations under the License. #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/framework/consistent_tensor_infer_cache.h" #include "oneflow/core/operator/op_conf.pb.h" -#include "oneflow/user/kernels/stateful_local_opkernel.h" +#include "oneflow/user/kernels/stateful_opkernel.h" namespace oneflow { namespace one { @@ -122,7 +122,7 @@ Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, return Maybe::Ok(); } -Maybe UserOpExpr::MutKernel4Stream(Symbol stream) const { +Maybe UserOpExpr::MutKernel4Stream(Symbol stream) const { const auto& it = stream2kernel_.find(stream); if (it != stream2kernel_.end()) { return it->second; } @@ -130,8 +130,8 @@ Maybe UserOpExpr::MutKernel4Stream(Symbol stream) JUST(BuildOpConf(op_conf.get(), {})); op_conf->set_device_tag(stream->device()->type()); auto parallel_desc = JUST(Placement4Device(stream->device())).shared_from_symbol(); - const auto& opkernel = JUST(StatefulLocalOpKernel::New( - op_conf, stream, base_attrs(), parallel_desc, input_arg_tuple(), output_arg_tuple())); + const auto& opkernel = JUST(StatefulOpKernel::New(op_conf, stream, base_attrs(), parallel_desc, + input_arg_tuple(), output_arg_tuple())); stream2kernel_.emplace(stream, opkernel); return opkernel; } diff --git a/oneflow/core/framework/op_expr.h b/oneflow/core/framework/op_expr.h index 5f76213a687..3806724c408 100644 --- a/oneflow/core/framework/op_expr.h +++ b/oneflow/core/framework/op_expr.h @@ -125,7 +125,7 @@ class BuiltinOpExprImpl : public BuiltinOpExpr { mutable std::shared_ptr op_grad_func_; }; -class StatefulLocalOpKernel; +class StatefulOpKernel; class ConsistentTensorInferCache; class UserOpExpr final : public BuiltinOpExprImpl { @@ -139,7 +139,7 @@ class UserOpExpr final : public BuiltinOpExprImpl { const AttrMap& base_attrs() const { return base_attrs_; } - Maybe MutKernel4Stream(Symbol stream) const; + Maybe MutKernel4Stream(Symbol stream) const; bool has_device_and_stream_infer_fn() const { return static_cast(device_and_stream_infer_fn_); @@ -172,7 +172,7 @@ class UserOpExpr final : public BuiltinOpExprImpl { user_op::TensorDescInferFn tensor_desc_infer_fn_; user_op::DataTypeInferFn dtype_infer_fn_; user_op::DeviceAndStreamInferFn device_and_stream_infer_fn_; - mutable HashMap, std::shared_ptr> stream2kernel_; + mutable HashMap, std::shared_ptr> stream2kernel_; std::shared_ptr consistent_tensor_infer_cache_; }; diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index c72f1a764ac..4c71d4f7300 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -29,7 +29,7 @@ limitations under the License. #include "oneflow/core/operator/operator.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" -#include "oneflow/user/kernels/stateful_local_opkernel.h" +#include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/tensor_consistent_id.h" @@ -50,7 +50,7 @@ Maybe> GetParallelDesc(const TensorTuple& inputs, } std::string GetDynamicOpConsistentFailedDebugString(const UserOpExpr& user_op_expr, - const StatefulLocalOpKernel& kernel) { + const StatefulOpKernel& kernel) { CHECK(!kernel.output_tuple_indexes4mut2_obns().empty()); std::string plentysuffix = kernel.output_tuple_indexes4mut2_obns().size() == 1 ? "s" : ""; std::stringstream ss; @@ -147,7 +147,7 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, if (unlikely(JUST(CachedIsAllZeroSizeTensorMeta(output_tensor_metas)))) { return Maybe::Ok(); } - // Run instruction LocalCallOpKernel + // Run instruction Call const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream())); CHECK_EQ_OR_RETURN(kernel->output_tuple_indexes4mut2_obns().size(), 0) << Error::UnimplementedError() @@ -179,8 +179,8 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, output_eager_blob_objects->at(i) = JUST(local_tensor->eager_blob_object()); } JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects, - result, ctx, result->stream()); + return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, result, ctx, + result->stream()); })); return Maybe::Ok(); } diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp index 8034dbfefb4..39353714be1 100644 --- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp @@ -29,7 +29,7 @@ limitations under the License. #include "oneflow/core/common/stride.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/operator/operator.h" -#include "oneflow/user/kernels/stateful_local_opkernel.h" +#include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/framework/placement_sbp_util.h" @@ -119,7 +119,7 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in // Infer devices if (!user_op_expr.has_device_and_stream_infer_fn()) { - stream = GetDefaultStreamByDevice(default_device); + stream = JUST(GetDefaultStreamByDevice(default_device)); for (int i = 0; i < outputs->size(); i++) { auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); *JUST(tensor_impl->mut_device()) = default_device; @@ -175,8 +175,7 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects, - ctx, stream); + return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream); })); return Maybe::Ok(); } diff --git a/oneflow/core/framework/stream.cpp b/oneflow/core/framework/stream.cpp index c10bf0cf4fa..ba9facf5b6f 100644 --- a/oneflow/core/framework/stream.cpp +++ b/oneflow/core/framework/stream.cpp @@ -17,49 +17,37 @@ limitations under the License. #include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/static_global.h" +#include "oneflow/core/common/global.h" #include "oneflow/core/job/parallel_desc.h" -#include "oneflow/core/vm/vm_object.h" -#include "oneflow/core/intrusive/intrusive.h" +#include "oneflow/core/framework/stream_mgr.h" namespace oneflow { -namespace { - -intrusive::shared_ptr RawGetStaticGlobalTransportLocalDepObject() { - return intrusive::make_shared(); -} +Stream::Stream(Symbol device, StreamRole stream_role) + : device_(device), stream_role_(stream_role), unique_stream_id_(-1) {} -intrusive::shared_ptr RawNewComputeDepObject(Symbol, StreamRole) { - return intrusive::make_shared(); +Maybe Stream::Init(size_t unique_stream_id) { + unique_stream_id_ = unique_stream_id; + return Maybe::Ok(); } -} // namespace - -LocalDepObject* GetStaticGlobalTransportLocalDepObject() { - static constexpr auto* GetLocalDepObject = - DECORATE(&RawGetStaticGlobalTransportLocalDepObject, StaticGlobalCopiable); - return GetLocalDepObject().Mutable(); +/*static*/ Maybe> Stream::RawNew(Symbol device, StreamRole stream_role) { + std::shared_ptr stream(new Stream(device, stream_role)); + return JUST(GlobalMaybe()) + ->AddStreamSymbol(*stream, [&](size_t unique_stream_id) -> Maybe> { + JUST(stream->Init(unique_stream_id)); + return SymbolOf(*stream); + }); } -Stream::Stream(Symbol device, StreamRole stream_role) - : device_(device), - stream_role_(stream_role), - schedule_local_dep_object_(nullptr), - transport_local_dep_object_(NullOpt) { - static constexpr auto* GetComputeDep = DECORATE(&RawNewComputeDepObject, StaticGlobalCopiable); - schedule_local_dep_object_ = GetComputeDep(device, stream_role).Mutable(); - if (StreamRoleSwitch(stream_role)) { - transport_local_dep_object_ = GetStaticGlobalTransportLocalDepObject(); - } +/*static*/ Maybe> Stream::New(Symbol device, StreamRole stream_role) { + constexpr auto* Make = DECORATE(&Stream::RawNew, ThreadLocal); + return Make(device, stream_role); } namespace { -Symbol RawNewStream(Symbol device, StreamRole stream_role) { - return SymbolOf(Stream(device, stream_role)); -} - -Symbol RawGetDefaultStreamByDevice(Symbol device) { +Maybe> RawGetDefaultStreamByDevice(Symbol device) { return Stream::New(device, StreamRole::kCompute); } @@ -69,8 +57,6 @@ Maybe> RawGetDefaultStreamByPlacement(Symbol parall } // namespace -decltype(Stream::New) Stream::New = DECORATE(&RawNewStream, ThreadLocal); - decltype(GetDefaultStreamByDevice) GetDefaultStreamByDevice = DECORATE(&RawGetDefaultStreamByDevice, ThreadLocal); diff --git a/oneflow/core/framework/stream.h b/oneflow/core/framework/stream.h index 52af85eb9d5..e851eb1e8e6 100644 --- a/oneflow/core/framework/stream.h +++ b/oneflow/core/framework/stream.h @@ -25,11 +25,6 @@ limitations under the License. namespace oneflow { -namespace vm { -class MirroredObject; -} -using LocalDepObject = vm::MirroredObject; - class Stream final { public: Stream(const Stream&) = default; @@ -41,29 +36,25 @@ class Stream final { } bool operator!=(const Stream& that) const { return !(*this == that); } - Stream(Symbol device, StreamRole stream_role); - - static Symbol (*New)(Symbol device, StreamRole stream_role); + static Maybe> New(Symbol device, StreamRole stream_role); Symbol device() const { return device_; } StreamRole stream_role() const { return stream_role_; } - - LocalDepObject* mut_schedule_local_dep_object() const { return schedule_local_dep_object_; } - const Optional& mut_transport_local_dep_object() const { - return transport_local_dep_object_; - } + size_t unique_stream_id() const { return unique_stream_id_; } private: + Stream(Symbol device, StreamRole stream_role); + + static Maybe> RawNew(Symbol device, StreamRole stream_role); + + Maybe Init(size_t unique_stream_id); + Symbol device_; StreamRole stream_role_; - - LocalDepObject* schedule_local_dep_object_; - Optional transport_local_dep_object_; + size_t unique_stream_id_; }; -LocalDepObject* GetStaticGlobalTransportLocalDepObject(); - -extern Symbol (*GetDefaultStreamByDevice)(Symbol); +extern Maybe> (*GetDefaultStreamByDevice)(Symbol); class ParallelDesc; extern Maybe> (*GetDefaultStreamByPlacement)(Symbol); diff --git a/oneflow/core/framework/stream_get_call_instruction_name.h b/oneflow/core/framework/stream_get_call_instruction_name.h deleted file mode 100644 index 774a3e2aaff..00000000000 --- a/oneflow/core/framework/stream_get_call_instruction_name.h +++ /dev/null @@ -1,99 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_GET_CALL_INSTRUCTION_NAME_H_ -#define ONEFLOW_CORE_FRAMEWORK_STREAM_GET_CALL_INSTRUCTION_NAME_H_ - -#include -#include -#include "oneflow/core/common/stream_role.h" -#include "oneflow/core/common/device_type.h" -#include "oneflow/core/common/maybe.h" -#include "oneflow/core/framework/to_string.h" - -namespace oneflow { - -struct GetCallInstructionName { - static Maybe Case(StreamRoleCase, - DeviceType device_type) { // NOLINT - static constexpr auto* Get = DECORATE(&Call::Invalid, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::Compute, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::Host2Device, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::Device2Host, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::SyncedLaunchedCommNet, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::AsyncedLaunchedCommNet, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::CriticalSection, ThreadLocal); - return *JUST(Get(device_type)); - } - - private: - struct Call { - static Maybe Invalid(DeviceType device_type) { // NOLINT - UNIMPLEMENTED_THEN_RETURN(); - } - static Maybe Compute(DeviceType device_type) { - return *JUST(DeviceTag4DeviceType(device_type)) + ".LocalCallOpKernel"; - } - static Maybe Host2Device(DeviceType device_type) { - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("cuda_h2d.LocalCallOpKernel"); - } - static Maybe Device2Host(DeviceType device_type) { - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("cuda_d2h.LocalCallOpKernel"); - } - static Maybe SyncedLaunchedCommNet(DeviceType device_type) { - if (device_type == kCPU) { return std::string("cpu.LocalCallOpKernel"); } - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("cuda.LocalCallOpKernel"); - } - static Maybe AsyncedLaunchedCommNet(DeviceType device_type) { - if (device_type == kCPU) { return std::string("cpu.LocalCallOpKernel"); } - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("async.cuda.LocalCallOpKernel"); - } - static Maybe CriticalSection(DeviceType device_type) { - UNIMPLEMENTED_THEN_RETURN(); - } - }; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_FRAMEWORK_STREAM_GET_CALL_INSTRUCTION_NAME_H_ diff --git a/oneflow/core/framework/stream_get_release_instruction_name.h b/oneflow/core/framework/stream_get_release_instruction_name.h deleted file mode 100644 index 262da8c29cc..00000000000 --- a/oneflow/core/framework/stream_get_release_instruction_name.h +++ /dev/null @@ -1,99 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_GET_RELEASE_INSTRUCTION_NAME_H_ -#define ONEFLOW_CORE_FRAMEWORK_STREAM_GET_RELEASE_INSTRUCTION_NAME_H_ - -#include -#include -#include "oneflow/core/common/stream_role.h" -#include "oneflow/core/common/device_type.h" -#include "oneflow/core/common/maybe.h" -#include "oneflow/core/framework/to_string.h" - -namespace oneflow { - -struct GetReleaseInstructionName { - static Maybe Case(StreamRoleCase, - DeviceType device_type) { // NOLINT - static constexpr auto* Get = DECORATE(&Call::Invalid, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::Compute, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::Host2Device, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::Device2Host, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::SyncedLaunchedCommNet, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::AsyncedLaunchedCommNet, ThreadLocal); - return *JUST(Get(device_type)); - } - static Maybe Case(StreamRoleCase, - DeviceType device_type) { - static constexpr auto* Get = DECORATE(&Call::CriticalSection, ThreadLocal); - return *JUST(Get(device_type)); - } - - private: - struct Call { - static Maybe Invalid(DeviceType device_type) { // NOLINT - UNIMPLEMENTED_THEN_RETURN(); - } - static Maybe Compute(DeviceType device_type) { - return *JUST(DeviceTag4DeviceType(device_type)) + ".ReleaseTensor"; - } - static Maybe Host2Device(DeviceType device_type) { - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("cuda_h2d.ReleaseTensor"); - } - static Maybe Device2Host(DeviceType device_type) { - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("cuda_d2h.ReleaseTensor"); - } - static Maybe SyncedLaunchedCommNet(DeviceType device_type) { - if (device_type == kCPU) { return std::string("comm_net.ReleaseTensor"); } - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("sync_launched_nccl.ReleaseTensor"); - } - static Maybe AsyncedLaunchedCommNet(DeviceType device_type) { - if (device_type == kCPU) { return std::string("comm_net.ReleaseTensor"); } - CHECK_EQ_OR_RETURN(device_type, kCUDA); - return std::string("async_launched_nccl.ReleaseTensor"); - } - static Maybe CriticalSection(DeviceType device_type) { - UNIMPLEMENTED_THEN_RETURN(); - } - }; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_FRAMEWORK_STREAM_GET_RELEASE_INSTRUCTION_NAME_H_ diff --git a/oneflow/core/framework/stream_get_stream_role_name.h b/oneflow/core/framework/stream_get_stream_role_name.h new file mode 100644 index 00000000000..b87148b2d6d --- /dev/null +++ b/oneflow/core/framework/stream_get_stream_role_name.h @@ -0,0 +1,40 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_ROLE_NAME_H_ +#define ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_ROLE_NAME_H_ + +#include +#include +#include "oneflow/core/common/stream_role.h" +#include "oneflow/core/common/device_type.h" +#include "oneflow/core/framework/to_string.h" + +namespace oneflow { + +struct GetStreamRoleName : public StreamRoleVisitor { + static const char* VisitCompute() { return "compute"; } + static const char* VisitHost2Device() { return "h2d"; } + static const char* VisitDevice2Host() { return "d2h"; } + static const char* VisitSyncedLaunchedCommNet() { return "synced_launched_comm_net"; } + static const char* VisitAsyncedLaunchedCommNet() { return "asynced_launched_comm_net"; } + static const char* VisitBarrier() { return "barrier"; } + static const char* VisitCriticalSection() { return "critical_section"; } + static const char* VisitLazyJobLauncher() { return "lazy_job_launcher"; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_ROLE_NAME_H_ diff --git a/oneflow/core/framework/stream_is_comm_net_stream.h b/oneflow/core/framework/stream_is_comm_net_stream.h index c60906c7ff1..ccc231948f1 100644 --- a/oneflow/core/framework/stream_is_comm_net_stream.h +++ b/oneflow/core/framework/stream_is_comm_net_stream.h @@ -21,16 +21,15 @@ limitations under the License. namespace oneflow { -struct IsCommNetStream { - static bool Case(StreamRoleCase) { // NOLINT - LOG(FATAL); - } - static bool Case(StreamRoleCase) { return false; } - static bool Case(StreamRoleCase) { return false; } - static bool Case(StreamRoleCase) { return false; } - static bool Case(StreamRoleCase) { return true; } - static bool Case(StreamRoleCase) { return true; } - static bool Case(StreamRoleCase) { return false; } +struct IsCommNetStream final : public StreamRoleVisitor { + static bool VisitCompute() { return false; } + static bool VisitHost2Device() { return false; } + static bool VisitDevice2Host() { return false; } + static bool VisitSyncedLaunchedCommNet() { return true; } + static bool VisitAsyncedLaunchedCommNet() { return true; } + static bool VisitBarrier() { return false; } + static bool VisitCriticalSection() { return false; } + static bool VisitLazyJobLauncher() { return false; } }; } // namespace oneflow diff --git a/oneflow/core/framework/stream_mgr.cpp b/oneflow/core/framework/stream_mgr.cpp new file mode 100644 index 00000000000..4c1e44ec85e --- /dev/null +++ b/oneflow/core/framework/stream_mgr.cpp @@ -0,0 +1,61 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/stream_mgr.h" +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/common/global.h" +#include "oneflow/core/common/util.h" + +namespace oneflow { + +Maybe> StreamMgr::AddStreamSymbol( + const Stream& stream, + const std::function>(size_t unique_stream_id)>& CreateStreamSymbol) { + Symbol stream_symbol; + std::unique_lock lock(mutex_); + if (stream2unique_stream_id_.count(stream) > 0) { + size_t unique_stream_id = stream2unique_stream_id_[stream]; + auto existed_stream_symbol = JUST(VectorAt(unique_stream_id2stream_symbol_, unique_stream_id)); + stream_symbol = JUST(CreateStreamSymbol(unique_stream_id)); + CHECK_OR_RETURN(existed_stream_symbol == stream_symbol) + << "the result of current called CreateStreamSymbol is not the result of last called " + "CreateStreamSymbol"; + } else { + size_t unique_stream_id = unique_stream_id2stream_symbol_.size(); + stream2unique_stream_id_[stream] = unique_stream_id; + stream_symbol = JUST(CreateStreamSymbol(unique_stream_id)); + unique_stream_id2stream_symbol_.push_back(stream_symbol); + CHECK_OR_RETURN(unique_stream_id2stream_symbol_[unique_stream_id] == stream) + << "the result of CreateStreamSymbol is no the symbol of `stream`"; + CHECK_EQ_OR_RETURN(unique_stream_id2stream_symbol_[unique_stream_id]->unique_stream_id(), + unique_stream_id) + << "unique_stream_id is wrongly initialized"; + } + return stream_symbol; +} + +size_t StreamMgr::UniqueStreamSize() const { + std::unique_lock lock(mutex_); + return unique_stream_id2stream_symbol_.size(); +} + +Maybe> StreamMgr::GetStreamSymbol(size_t unique_stream_id) const { + std::unique_lock lock(mutex_); + return JUST(VectorAt(unique_stream_id2stream_symbol_, unique_stream_id)); +} + +COMMAND(Global::SetAllocated(new StreamMgr())); + +} // namespace oneflow diff --git a/oneflow/core/framework/stream_mgr.h b/oneflow/core/framework/stream_mgr.h new file mode 100644 index 00000000000..a38ee2b183e --- /dev/null +++ b/oneflow/core/framework/stream_mgr.h @@ -0,0 +1,48 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_ +#define ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_ + +#include +#include +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/optional.h" +#include "oneflow/core/framework/stream.h" + +namespace oneflow { + +class StreamMgr final { + public: + StreamMgr() = default; + ~StreamMgr() = default; + + Maybe> AddStreamSymbol( + const Stream& stream, + const std::function>(size_t unique_stream_id)>& CreateStreamSymbol); + + size_t UniqueStreamSize() const; + + Maybe> GetStreamSymbol(size_t unique_stream_id) const; + + private: + mutable std::mutex mutex_; + std::vector> unique_stream_id2stream_symbol_; + std::unordered_map stream2unique_stream_id_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_ diff --git a/oneflow/core/framework/stream_need_soft_sync.h b/oneflow/core/framework/stream_need_soft_sync.h index d783c8f4d2c..35dcb71fd30 100644 --- a/oneflow/core/framework/stream_need_soft_sync.h +++ b/oneflow/core/framework/stream_need_soft_sync.h @@ -22,22 +22,15 @@ limitations under the License. namespace oneflow { -struct NeedSoftSync { - static bool Case(StreamRoleCase, DeviceType) { // NOLINT - LOG(FATAL); - } - static bool Case(StreamRoleCase, DeviceType device_type) { - return device_type != kCPU; - } - static bool Case(StreamRoleCase, DeviceType) { return false; } - static bool Case(StreamRoleCase, DeviceType) { return false; } - static bool Case(StreamRoleCase, DeviceType device_type) { - return device_type != kCPU; - } - static bool Case(StreamRoleCase, DeviceType) { - return false; - } - static bool Case(StreamRoleCase, DeviceType) { return false; } +struct NeedSoftSync : public StreamRoleVisitor { + static bool VisitCompute(DeviceType device_type) { return device_type != kCPU; } + static bool VisitHost2Device(DeviceType) { return false; } + static bool VisitDevice2Host(DeviceType) { return false; } + static bool VisitSyncedLaunchedCommNet(DeviceType device_type) { return device_type != kCPU; } + static bool VisitAsyncedLaunchedCommNet(DeviceType) { return false; } + static bool VisitBarrier(DeviceType) { return false; } + static bool VisitCriticalSection(DeviceType) { return false; } + static bool VisitLazyJobLauncher(DeviceType) { return false; } }; } // namespace oneflow diff --git a/oneflow/core/framework/stream_on_independent_thread.h b/oneflow/core/framework/stream_on_independent_thread.h new file mode 100644 index 00000000000..54795a6f746 --- /dev/null +++ b/oneflow/core/framework/stream_on_independent_thread.h @@ -0,0 +1,37 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_ +#define ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_ + +#include +#include "oneflow/core/common/stream_role.h" + +namespace oneflow { + +struct StreamOnIndependentThread : public StreamRoleVisitor { + static bool VisitCompute() { return false; } + static bool VisitHost2Device() { return false; } + static bool VisitDevice2Host() { return false; } + static bool VisitSyncedLaunchedCommNet() { return false; } + static bool VisitAsyncedLaunchedCommNet() { return false; } + static bool VisitBarrier() { return false; } + static bool VisitCriticalSection() { return true; } + static bool VisitLazyJobLauncher() { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_ diff --git a/oneflow/core/framework/tensor_consistent_id.cpp b/oneflow/core/framework/tensor_consistent_id.cpp index bcaf69e4142..f004f81c464 100644 --- a/oneflow/core/framework/tensor_consistent_id.cpp +++ b/oneflow/core/framework/tensor_consistent_id.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/decorator.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/transport_token.h" diff --git a/oneflow/core/framework/tensor_impl.cpp b/oneflow/core/framework/tensor_impl.cpp index 558b57a72c1..832fc8b4d8d 100644 --- a/oneflow/core/framework/tensor_impl.cpp +++ b/oneflow/core/framework/tensor_impl.cpp @@ -83,12 +83,11 @@ EagerMirroredTensorImpl::EagerMirroredTensorImpl( Maybe EagerMirroredTensorImpl::UpdateTensorStorage() { const auto& eager_blob_object = eager_blob_object_; tensor_storage_ = std::make_shared(eager_blob_object->tensor_storage()); - const auto& parallel_desc = JUST(Placement4Device(this->device())).shared_from_symbol(); tensor_storage_->set_releaser_hook( - [eager_blob_object, parallel_desc](const std::shared_ptr&) { + [eager_blob_object](const std::shared_ptr&) { CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { if (eager_blob_object->producer_stream().has_value()) { - JUST(builder->ReleaseTensor(eager_blob_object, parallel_desc)); + JUST(builder->ReleaseTensor(eager_blob_object)); } return Maybe::Ok(); })); diff --git a/oneflow/core/vm/barrier_instruction_type.h b/oneflow/core/vm/barrier_instruction_type.h new file mode 100644 index 00000000000..f6f3e20edc2 --- /dev/null +++ b/oneflow/core/vm/barrier_instruction_type.h @@ -0,0 +1,66 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_TYPE_H_ +#define ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_TYPE_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/intrusive/flat_msg_view.h" +#include "oneflow/core/rpc/include/base.h" +#include "oneflow/core/vm/control_stream_type.h" +#include "oneflow/core/vm/instruction_type.h" +#include "oneflow/core/vm/instruction.h" +#include "oneflow/core/vm/virtual_machine_engine.h" +#include "oneflow/core/vm/barrier_phy_instr_operand.h" +#include "oneflow/core/control/global_process_ctx.h" + +namespace oneflow { +namespace vm { + +class BarrierInstructionType : public InstructionType { + public: + BarrierInstructionType() = default; + virtual ~BarrierInstructionType() override = default; + + bool IsBarrier() const override { return true; } + + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { return "Barrier"; } + void Compute(Instruction* instruction) const override { Run(instruction->instr_msg()); } + void ComputeInFuseMode(InstructionMsg* instr_msg) const override { Run(*instr_msg); } + + protected: + void Run(const InstructionMsg& instr_msg) const { + const auto* operand = + dynamic_cast(instr_msg.phy_instr_operand().get()); + CHECK_NOTNULL(operand)->callback(); + } +}; + +class GlobalSyncInstructionType : public InstructionType { + public: + GlobalSyncInstructionType() = default; + virtual ~GlobalSyncInstructionType() override = default; + + bool IsBarrier() const override { return true; } + + std::string DebugName(const vm::InstructionMsg& instr_msg) const override { return "GlobalSync"; } + void Compute(Instruction* instruction) const override { OF_ENV_BARRIER(); } + void ComputeInFuseMode(InstructionMsg* instr_msg) const override { OF_ENV_BARRIER(); } +}; + +} // namespace vm +} // namespace oneflow + +#endif // ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_TYPE_H_ diff --git a/oneflow/core/vm/control_stream_type.cpp b/oneflow/core/vm/control_stream_type.cpp index 931f9b2ae2b..f007ea33812 100644 --- a/oneflow/core/vm/control_stream_type.cpp +++ b/oneflow/core/vm/control_stream_type.cpp @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/vm/stream_desc.h" #include "oneflow/core/vm/control_stream_type.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/vm/instruction.h" @@ -27,8 +26,7 @@ namespace oneflow { namespace vm { void ControlStreamType::Compute(Instruction* instruction) const { - const auto& instr_type_id = instruction->instr_msg().instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); + instruction->instr_msg().instruction_type().Compute(instruction); auto* status_buffer = instruction->mut_status_buffer(); NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer()->mut_data())->set_done(); } @@ -50,14 +48,5 @@ bool ControlStreamType::QueryInstructionStatusDone( return NaiveInstrStatusQuerier::Cast(status_buffer.buffer().data())->done(); } -intrusive::shared_ptr ControlStreamType::MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const { - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(1); - ret->set_num_streams_per_thread(1); - return ret; -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/vm/control_stream_type.h b/oneflow/core/vm/control_stream_type.h index a5e66dcd6a5..622bf318d93 100644 --- a/oneflow/core/vm/control_stream_type.h +++ b/oneflow/core/vm/control_stream_type.h @@ -29,8 +29,6 @@ class ControlStreamType final : public StreamType { ControlStreamType() = default; ~ControlStreamType() = default; - const char* stream_tag() const override { return "control"; } - void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override {} void InitInstructionStatus(const Stream& stream, @@ -39,8 +37,6 @@ class ControlStreamType final : public StreamType { InstructionStatusBuffer* status_buffer) const override; bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; void Compute(Instruction* instruction) const override; bool OnSchedulerThread() const override { return true; } diff --git a/oneflow/core/vm/cpu_stream_type.cpp b/oneflow/core/vm/cpu_stream_type.cpp index ca61f0aba73..8e04d05f8ba 100644 --- a/oneflow/core/vm/cpu_stream_type.cpp +++ b/oneflow/core/vm/cpu_stream_type.cpp @@ -49,24 +49,10 @@ bool CpuStreamType::QueryInstructionStatusDone(const Stream& stream, void CpuStreamType::Compute(Instruction* instruction) const { OF_PROFILER_RANGE_GUARD("S:" + instruction->instr_msg().DebugName()); - { - const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); - } + instruction->instr_msg().instruction_type().Compute(instruction); auto* status_buffer = instruction->mut_status_buffer(); NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer()->mut_data())->set_done(); } -intrusive::shared_ptr CpuStreamType::MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const { - if (!resource.has_cpu_device_num()) { return intrusive::shared_ptr(); } - std::size_t device_num = resource.cpu_device_num(); - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(device_num); - return ret; -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/vm/cpu_stream_type.h b/oneflow/core/vm/cpu_stream_type.h index 304f1ff29e7..f94226ac7c1 100644 --- a/oneflow/core/vm/cpu_stream_type.h +++ b/oneflow/core/vm/cpu_stream_type.h @@ -30,8 +30,6 @@ class CpuStreamType final : public StreamType { CpuStreamType() = default; ~CpuStreamType() override = default; - const char* stream_tag() const override { return "cpu"; } - void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override; void InitInstructionStatus(const Stream& stream, @@ -41,8 +39,6 @@ class CpuStreamType final : public StreamType { bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; void Compute(Instruction* instruction) const override; - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; bool OnSchedulerThread() const override { return false; } bool SupportingTransportInstructions() const override { return true; } }; diff --git a/oneflow/core/eager/critical_section_status_querier.h b/oneflow/core/vm/critical_section_status_querier.h similarity index 91% rename from oneflow/core/eager/critical_section_status_querier.h rename to oneflow/core/vm/critical_section_status_querier.h index 6b5293a7789..8e26fccf4d1 100644 --- a/oneflow/core/eager/critical_section_status_querier.h +++ b/oneflow/core/vm/critical_section_status_querier.h @@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_EAGER_CRITICAL_SECTION_QUERIER_H_ -#define ONEFLOW_CORE_EAGER_CRITICAL_SECTION_QUERIER_H_ +#ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_ +#define ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_ #include #include @@ -58,4 +58,4 @@ class CriticalSectionStatusQuerier final { } // namespace vm } // namespace oneflow -#endif // ONEFLOW_CORE_EAGER_CRITICAL_SECTION_QUERIER_H_ +#endif // ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_ diff --git a/oneflow/core/eager/critical_section_stream_type.cpp b/oneflow/core/vm/critical_section_stream_type.cpp similarity index 75% rename from oneflow/core/eager/critical_section_stream_type.cpp rename to oneflow/core/vm/critical_section_stream_type.cpp index 86f9a7a8b72..b718fafc220 100644 --- a/oneflow/core/eager/critical_section_stream_type.cpp +++ b/oneflow/core/vm/critical_section_stream_type.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/eager/critical_section_stream_type.h" +#include "oneflow/core/vm/critical_section_stream_type.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/thread_ctx.h" -#include "oneflow/core/eager/critical_section_status_querier.h" +#include "oneflow/core/vm/critical_section_status_querier.h" #include "oneflow/core/common/util.h" namespace oneflow { @@ -47,19 +47,7 @@ bool CriticalSectionStreamType::QueryInstructionStatusDone( } void CriticalSectionStreamType::Compute(Instruction* instruction) const { - { - const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); - } -} - -intrusive::shared_ptr CriticalSectionStreamType::MakeStreamDesc( - const Resource& resource, int64_t this_machine_id) const { - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(1); - ret->set_num_streams_per_thread(1); - return ret; + instruction->instr_msg().instruction_type().Compute(instruction); } } // namespace vm diff --git a/oneflow/core/eager/critical_section_stream_type.h b/oneflow/core/vm/critical_section_stream_type.h similarity index 80% rename from oneflow/core/eager/critical_section_stream_type.h rename to oneflow/core/vm/critical_section_stream_type.h index b71ace70090..f4ad4e9a5e7 100644 --- a/oneflow/core/eager/critical_section_stream_type.h +++ b/oneflow/core/vm/critical_section_stream_type.h @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_EAGER_CRITICAL_SECTION_STREAM_TYPE_H_ -#define ONEFLOW_CORE_EAGER_CRITICAL_SECTION_STREAM_TYPE_H_ +#ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_TYPE_H_ +#define ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_TYPE_H_ #include "oneflow/core/intrusive/flat_msg_view.h" #include "oneflow/core/vm/stream_type.h" @@ -31,8 +31,6 @@ class CriticalSectionStreamType final : public StreamType { CriticalSectionStreamType() = default; virtual ~CriticalSectionStreamType() = default; - const char* stream_tag() const override { return "critical_section"; } - void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override; void InitInstructionStatus(const Stream& stream, @@ -44,11 +42,9 @@ class CriticalSectionStreamType final : public StreamType { void Compute(Instruction* instruction) const override; bool OnSchedulerThread() const override { return false; } bool SupportingTransportInstructions() const override { return false; } - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; }; } // namespace vm } // namespace oneflow -#endif // ONEFLOW_CORE_EAGER_CRITICAL_SECTION_STREAM_TYPE_H_ +#endif // ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_TYPE_H_ diff --git a/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp b/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp index ee1acaaeb49..2437b5d3521 100644 --- a/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp +++ b/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp @@ -55,27 +55,12 @@ bool CudaCopyD2HStreamType::QueryInstructionStatusDone( void CudaCopyD2HStreamType::Compute(Instruction* instruction) const { auto* stream = instruction->mut_stream(); cudaSetDevice(stream->device_id()); - { - const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); - OF_CUDA_CHECK(cudaGetLastError()); - } + instruction->instr_msg().instruction_type().Compute(instruction); + OF_CUDA_CHECK(cudaGetLastError()); char* data_ptr = instruction->mut_status_buffer()->mut_buffer()->mut_data(); CudaOptionalEventRecordStatusQuerier::MutCast(data_ptr)->SetLaunched(stream->device_ctx().get()); } -// Specifies copy_d2h stream description of the virtual machine to be used. -intrusive::shared_ptr CudaCopyD2HStreamType::MakeStreamDesc( - const Resource& resource, int64_t this_machine_id) const { - if (!resource.has_gpu_device_num()) { return intrusive::shared_ptr(); } - std::size_t device_num = resource.gpu_device_num(); - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(device_num); - return ret; -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/vm/cuda_copy_d2h_stream_type.h b/oneflow/core/vm/cuda_copy_d2h_stream_type.h index 4ba2bc3cfa0..c8039af3537 100644 --- a/oneflow/core/vm/cuda_copy_d2h_stream_type.h +++ b/oneflow/core/vm/cuda_copy_d2h_stream_type.h @@ -37,8 +37,6 @@ class CudaCopyD2HStreamType final : public StreamType { CudaCopyD2HStreamType() = default; ~CudaCopyD2HStreamType() = default; - const char* stream_tag() const override { return "cuda_d2h"; } - void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override; void InitInstructionStatus(const Stream& stream, @@ -48,8 +46,6 @@ class CudaCopyD2HStreamType final : public StreamType { bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; void Compute(Instruction* instruction) const override; - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; bool OnSchedulerThread() const override { return true; } bool SupportingTransportInstructions() const override { return false; } }; diff --git a/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp b/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp index 84dcc316457..8bfba60c214 100644 --- a/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp +++ b/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp @@ -49,26 +49,12 @@ bool CudaCopyH2DStreamType::QueryInstructionStatusDone( void CudaCopyH2DStreamType::Compute(Instruction* instruction) const { auto* stream = instruction->mut_stream(); cudaSetDevice(stream->device_id()); - { - const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); - OF_CUDA_CHECK(cudaGetLastError()); - } + instruction->instr_msg().instruction_type().Compute(instruction); + OF_CUDA_CHECK(cudaGetLastError()); char* data_ptr = instruction->mut_status_buffer()->mut_buffer()->mut_data(); CudaOptionalEventRecordStatusQuerier::MutCast(data_ptr)->SetLaunched(stream->device_ctx().get()); } -intrusive::shared_ptr CudaCopyH2DStreamType::MakeStreamDesc( - const Resource& resource, int64_t this_machine_id) const { - if (!resource.has_gpu_device_num()) { return intrusive::shared_ptr(); } - std::size_t device_num = resource.gpu_device_num(); - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(device_num); - return ret; -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/vm/cuda_copy_h2d_stream_type.h b/oneflow/core/vm/cuda_copy_h2d_stream_type.h index 24237260544..22e6180b0eb 100644 --- a/oneflow/core/vm/cuda_copy_h2d_stream_type.h +++ b/oneflow/core/vm/cuda_copy_h2d_stream_type.h @@ -36,8 +36,6 @@ class CudaCopyH2DStreamType final : public StreamType { CudaCopyH2DStreamType() = default; ~CudaCopyH2DStreamType() = default; - const char* stream_tag() const override { return "cuda_h2d"; } - void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override; void InitInstructionStatus(const Stream& stream, @@ -47,8 +45,6 @@ class CudaCopyH2DStreamType final : public StreamType { bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; void Compute(Instruction* instruction) const override; - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; bool OnSchedulerThread() const override { return true; } bool SupportingTransportInstructions() const override { return false; } }; diff --git a/oneflow/core/vm/cuda_stream_type.cpp b/oneflow/core/vm/cuda_stream_type.cpp index 671986aa5ae..0498e1680c3 100644 --- a/oneflow/core/vm/cuda_stream_type.cpp +++ b/oneflow/core/vm/cuda_stream_type.cpp @@ -55,27 +55,13 @@ void CudaStreamType::Compute(Instruction* instruction) const { OF_PROFILER_RANGE_PUSH("S:" + instruction->instr_msg().DebugName()); auto* stream = instruction->mut_stream(); cudaSetDevice(stream->device_id()); - { - const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); - OF_CUDA_CHECK(cudaGetLastError()); - } + instruction->instr_msg().instruction_type().Compute(instruction); + OF_CUDA_CHECK(cudaGetLastError()); char* data_ptr = instruction->mut_status_buffer()->mut_buffer()->mut_data(); CudaOptionalEventRecordStatusQuerier::MutCast(data_ptr)->SetLaunched(stream->device_ctx().get()); OF_PROFILER_RANGE_POP(); } -intrusive::shared_ptr CudaStreamType::MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const { - if (!resource.has_gpu_device_num()) { return intrusive::shared_ptr(); } - std::size_t device_num = resource.gpu_device_num(); - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(device_num); - return ret; -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/vm/cuda_stream_type.h b/oneflow/core/vm/cuda_stream_type.h index 9dce5146827..cfaf855f486 100644 --- a/oneflow/core/vm/cuda_stream_type.h +++ b/oneflow/core/vm/cuda_stream_type.h @@ -32,8 +32,6 @@ class CudaStreamType final : public StreamType { CudaStreamType() = default; ~CudaStreamType() override = default; - const char* stream_tag() const override { return "cuda"; } - void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override; void InitInstructionStatus(const Stream& stream, @@ -43,8 +41,6 @@ class CudaStreamType final : public StreamType { bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; void Compute(Instruction* instruction) const override; - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; bool OnSchedulerThread() const override { return true; } bool SupportingTransportInstructions() const override { return true; } }; diff --git a/oneflow/core/vm/async_cuda_stream_type.cpp b/oneflow/core/vm/event_recorded_cuda_stream_type.cpp similarity index 60% rename from oneflow/core/vm/async_cuda_stream_type.cpp rename to oneflow/core/vm/event_recorded_cuda_stream_type.cpp index e18bd824224..161cec36ef1 100644 --- a/oneflow/core/vm/async_cuda_stream_type.cpp +++ b/oneflow/core/vm/event_recorded_cuda_stream_type.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef WITH_CUDA -#include "oneflow/core/vm/async_cuda_stream_type.h" +#include "oneflow/core/vm/event_recorded_cuda_stream_type.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/cuda_stream_handle_device_context.h" @@ -25,13 +25,13 @@ limitations under the License. namespace oneflow { namespace vm { -void AsyncCudaStreamType::InitDeviceCtx(std::unique_ptr* device_ctx, - Stream* stream) const { +void EventRecordedCudaStreamType::InitDeviceCtx(std::unique_ptr* device_ctx, + Stream* stream) const { device_ctx->reset(new CudaStreamHandleDeviceCtx(stream->device_id())); } -void AsyncCudaStreamType::InitInstructionStatus(const Stream& stream, - InstructionStatusBuffer* status_buffer) const { +void EventRecordedCudaStreamType::InitInstructionStatus( + const Stream& stream, InstructionStatusBuffer* status_buffer) const { static_assert(sizeof(CudaOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, ""); auto* event_provider = dynamic_cast(stream.device_ctx().get()); auto* data_ptr = status_buffer->mut_buffer()->mut_data(); @@ -39,42 +39,28 @@ void AsyncCudaStreamType::InitInstructionStatus(const Stream& stream, CudaOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, cuda_event); } -void AsyncCudaStreamType::DeleteInstructionStatus(const Stream& stream, - InstructionStatusBuffer* status_buffer) const { +void EventRecordedCudaStreamType::DeleteInstructionStatus( + const Stream& stream, InstructionStatusBuffer* status_buffer) const { auto* ptr = CudaOptionalEventRecordStatusQuerier::MutCast(status_buffer->mut_buffer()->mut_data()); ptr->~CudaOptionalEventRecordStatusQuerier(); } -bool AsyncCudaStreamType::QueryInstructionStatusDone( +bool EventRecordedCudaStreamType::QueryInstructionStatusDone( const Stream& stream, const InstructionStatusBuffer& status_buffer) const { return CudaOptionalEventRecordStatusQuerier::Cast(status_buffer.buffer().data())->done(); } -void AsyncCudaStreamType::Compute(Instruction* instruction) const { +void EventRecordedCudaStreamType::Compute(Instruction* instruction) const { OF_PROFILER_RANGE_GUARD("S:" + instruction->instr_msg().DebugName()); auto* stream = instruction->mut_stream(); cudaSetDevice(stream->device_id()); - { - const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); - OF_CUDA_CHECK(cudaGetLastError()); - } + instruction->instr_msg().instruction_type().Compute(instruction); + OF_CUDA_CHECK(cudaGetLastError()); char* data_ptr = instruction->mut_status_buffer()->mut_buffer()->mut_data(); CudaOptionalEventRecordStatusQuerier::MutCast(data_ptr)->SetLaunched(stream->device_ctx().get()); } -intrusive::shared_ptr AsyncCudaStreamType::MakeStreamDesc( - const Resource& resource, int64_t this_machine_id) const { - if (!resource.has_gpu_device_num()) { return intrusive::shared_ptr(); } - std::size_t device_num = resource.gpu_device_num(); - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(device_num); - return ret; -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/vm/async_cuda_stream_type.h b/oneflow/core/vm/event_recorded_cuda_stream_type.h similarity index 75% rename from oneflow/core/vm/async_cuda_stream_type.h rename to oneflow/core/vm/event_recorded_cuda_stream_type.h index 52094e4b578..238f2c505ab 100644 --- a/oneflow/core/vm/async_cuda_stream_type.h +++ b/oneflow/core/vm/event_recorded_cuda_stream_type.h @@ -15,8 +15,8 @@ limitations under the License. */ #ifdef WITH_CUDA -#ifndef ONEFLOW_CORE_VM_ASYNC_CUDA_STREAM_TYPE_H_ -#define ONEFLOW_CORE_VM_ASYNC_CUDA_STREAM_TYPE_H_ +#ifndef ONEFLOW_CORE_VM_EVENT_RECORDED_CUDA_STREAM_TYPE_H_ +#define ONEFLOW_CORE_VM_EVENT_RECORDED_CUDA_STREAM_TYPE_H_ #include "oneflow/core/intrusive/flat_msg_view.h" #include "oneflow/core/vm/stream_type.h" @@ -27,12 +27,10 @@ limitations under the License. namespace oneflow { namespace vm { -class AsyncCudaStreamType final : public StreamType { +class EventRecordedCudaStreamType final : public StreamType { public: - AsyncCudaStreamType() = default; - ~AsyncCudaStreamType() override = default; - - const char* stream_tag() const override { return "async_launched_nccl"; } + EventRecordedCudaStreamType() = default; + ~EventRecordedCudaStreamType() override = default; void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override; @@ -43,8 +41,6 @@ class AsyncCudaStreamType final : public StreamType { bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const override; void Compute(Instruction* instruction) const override; - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; bool OnSchedulerThread() const override { return true; } bool SupportingTransportInstructions() const override { return true; } }; @@ -52,5 +48,5 @@ class AsyncCudaStreamType final : public StreamType { } // namespace vm } // namespace oneflow -#endif // ONEFLOW_CORE_VM_ASYNC_CUDA_STREAM_TYPE_H_ +#endif // ONEFLOW_CORE_VM_EVENT_RECORDED_CUDA_STREAM_TYPE_H_ #endif // WITH_CUDA diff --git a/oneflow/core/vm/fuse_instruction_type.cpp b/oneflow/core/vm/fuse_instruction_type.h similarity index 58% rename from oneflow/core/vm/fuse_instruction_type.cpp rename to oneflow/core/vm/fuse_instruction_type.h index fe2d060b69b..25fd45bb127 100644 --- a/oneflow/core/vm/fuse_instruction_type.cpp +++ b/oneflow/core/vm/fuse_instruction_type.h @@ -13,28 +13,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifndef ONEFLOW_CORE_VM_FUSE_INSTRUCTION_TYPE_H_ +#define ONEFLOW_CORE_VM_FUSE_INSTRUCTION_TYPE_H_ + #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/fuse_phy_instr_operand.h" -#include "oneflow/core/vm/cuda_stream_type.h" -#include "oneflow/core/vm/async_cuda_stream_type.h" -#include "oneflow/core/vm/cuda_copy_h2d_stream_type.h" -#include "oneflow/core/vm/cuda_copy_d2h_stream_type.h" -#include "oneflow/core/vm/cpu_stream_type.h" #include "oneflow/core/profiler/profiler.h" namespace oneflow { namespace vm { -template class FuseInstructionType : public vm::InstructionType { public: FuseInstructionType() = default; ~FuseInstructionType() override = default; - using stream_type = StreamT; - - std::string DebugOpTypeName(const InstructionMsg&) const override { return "Fuse"; } + std::string DebugName(const InstructionMsg&) const override { return "Fuse"; } void InitInstructionStatus(Instruction* instruction) const override { const auto& phy_instr_operand = instruction->instr_msg().phy_instr_operand(); @@ -42,7 +37,7 @@ class FuseInstructionType : public vm::InstructionType { auto* instr_msg_list = CHECK_NOTNULL(ptr)->mut_instr_msg_list(); auto* last_instr_msg = CHECK_NOTNULL(instr_msg_list->Last()); // init instruction status by last instruction_msg. - last_instr_msg->instr_type_id().instruction_type().InitInstructionStatusIf(instruction); + last_instr_msg->instruction_type().InitInstructionStatusIf(instruction); } void Compute(vm::Instruction* instruction) const override { @@ -51,23 +46,12 @@ class FuseInstructionType : public vm::InstructionType { auto* instr_msg_list = CHECK_NOTNULL(ptr)->mut_instr_msg_list(); INTRUSIVE_UNSAFE_FOR_EACH_PTR(instr_msg, instr_msg_list) { OF_PROFILER_RANGE_GUARD("F:" + instr_msg->DebugName()); - instr_msg->instr_type_id().instruction_type().ComputeInFuseMode(instr_msg); + instr_msg->instruction_type().ComputeInFuseMode(instr_msg); } } }; -COMMAND(vm::RegisterInstructionType>("cpu.Fuse")); -COMMAND(vm::RegisterInstructionType>("comm_net.Fuse")); - -#ifdef WITH_CUDA -COMMAND(vm::RegisterInstructionType>("cuda.Fuse")); -COMMAND(vm::RegisterInstructionType>("cuda_h2d.Fuse")); -COMMAND(vm::RegisterInstructionType>("cuda_d2h.Fuse")); -COMMAND( - vm::RegisterInstructionType>("sync_launched_nccl.Fuse")); -COMMAND(vm::RegisterInstructionType>( - "async_launched_nccl.Fuse")); -#endif - } // namespace vm } // namespace oneflow + +#endif // ONEFLOW_CORE_VM_FUSE_INSTRUCTION_TYPE_H_ diff --git a/oneflow/core/vm/fuse_phy_instr_operand.h b/oneflow/core/vm/fuse_phy_instr_operand.h index b9af5ae0004..258ab206f03 100644 --- a/oneflow/core/vm/fuse_phy_instr_operand.h +++ b/oneflow/core/vm/fuse_phy_instr_operand.h @@ -35,13 +35,10 @@ class FusePhyInstrOperand : public PhyInstrOperand { auto* last_instr_msg = instr_msg_list_.Last(); INTRUSIVE_UNSAFE_FOR_EACH_PTR(instr_msg, &instr_msg_list_) { if (instr_msg == last_instr_msg) { - CHECK(instr_msg->instr_type_id().instruction_type().fuse_type() - == kEnableInstructionFuseAsTailOnly - || instr_msg->instr_type_id().instruction_type().fuse_type() - == kEnableInstructionFuseAtAnyPosition); + CHECK(instr_msg->instruction_type().fuse_type() == kEnableInstructionFuseAsTailOnly + || instr_msg->instruction_type().fuse_type() == kEnableInstructionFuseAtAnyPosition); } else { - CHECK(instr_msg->instr_type_id().instruction_type().fuse_type() - == kEnableInstructionFuseAtAnyPosition); + CHECK(instr_msg->instruction_type().fuse_type() == kEnableInstructionFuseAtAnyPosition); } if (unlikely(stream_sequential_dependence_ == nullptr)) { stream_sequential_dependence_ = diff --git a/oneflow/core/vm/id_generator.cpp b/oneflow/core/vm/id_generator.cpp deleted file mode 100644 index 61232a5b082..00000000000 --- a/oneflow/core/vm/id_generator.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/control/global_process_ctx.h" -#include "oneflow/core/vm/id_generator.h" -#include "oneflow/core/vm/id_util.h" - -namespace oneflow { -namespace vm { - -Maybe LogicalIdGenerator::NewSymbolId() { - // NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to - // PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master. - return IdUtil::NewPhysicalSymbolId(GlobalProcessCtx::Rank()); -} - -Maybe LogicalIdGenerator::NewObjectId() { - // NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to - // PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master. - return IdUtil::NewPhysicalObjectId(GlobalProcessCtx::Rank()); -} - -Maybe PhysicalIdGenerator::NewSymbolId() { - return IdUtil::NewPhysicalSymbolId(GlobalProcessCtx::Rank()); -} - -Maybe PhysicalIdGenerator::NewObjectId() { - return IdUtil::NewPhysicalObjectId(GlobalProcessCtx::Rank()); -} - -} // namespace vm -} // namespace oneflow diff --git a/oneflow/core/vm/id_generator.h b/oneflow/core/vm/id_generator.h deleted file mode 100644 index 58a03a3d898..00000000000 --- a/oneflow/core/vm/id_generator.h +++ /dev/null @@ -1,60 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_VM_ID_GENERATOR_H_ -#define ONEFLOW_CORE_VM_ID_GENERATOR_H_ - -#include "oneflow/core/common/maybe.h" - -namespace oneflow { -namespace vm { - -class IdGenerator { - public: - virtual ~IdGenerator() = default; - - virtual Maybe NewSymbolId() = 0; - virtual Maybe NewObjectId() = 0; - - protected: - IdGenerator() = default; -}; - -class LogicalIdGenerator : public IdGenerator { - public: - LogicalIdGenerator(const LogicalIdGenerator&) = delete; - LogicalIdGenerator(LogicalIdGenerator&&) = delete; - LogicalIdGenerator() = default; - ~LogicalIdGenerator() override = default; - - Maybe NewSymbolId() override; - Maybe NewObjectId() override; -}; - -class PhysicalIdGenerator : public IdGenerator { - public: - PhysicalIdGenerator(const PhysicalIdGenerator&) = delete; - PhysicalIdGenerator(PhysicalIdGenerator&&) = delete; - PhysicalIdGenerator() = default; - ~PhysicalIdGenerator() override = default; - - Maybe NewSymbolId() override; - Maybe NewObjectId() override; -}; - -} // namespace vm -} // namespace oneflow - -#endif // ONEFLOW_CORE_VM_ID_GENERATOR_H_ diff --git a/oneflow/core/vm/id_util.cpp b/oneflow/core/vm/id_util.cpp deleted file mode 100644 index 5191f04514c..00000000000 --- a/oneflow/core/vm/id_util.cpp +++ /dev/null @@ -1,91 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include -#include -#include "oneflow/core/vm/id_util.h" - -namespace oneflow { -namespace vm { - -namespace { - -static const int64_t kObjectIdMaximumValue = LLONG_MAX / 2; -static const int64_t kMachineNumberLimit = (1 << 12); -static const int64_t kErrorCodeLimit = 4096; - -static_assert(kMachineNumberLimit >= kErrorCodeLimit, ""); - -int64_t ObjectIdCounter() { - static int64_t counter = 0; - return (counter += kMachineNumberLimit); -} - -int64_t NewLogicalObjectIdFromCounter() { return ObjectIdCounter() + kMachineNumberLimit - 1; } - -int64_t NewPhysicalObjectIdFromCounter(int32_t machine_id) { - CHECK_LT(machine_id, kMachineNumberLimit - 1); - return ObjectIdCounter() + machine_id; -} - -} // namespace - -int64_t IdUtil::IsErrorId(int64_t id) { return id >= -kErrorCodeLimit && id <= kErrorCodeLimit; } - -int64_t IdUtil::NewLogicalValueObjectId() { - int64_t val = NewLogicalObjectIdFromCounter(); - CHECK_LT(val, kObjectIdMaximumValue); - return val; -} - -int64_t IdUtil::NewLogicalValueSymbolId() { - return NewLogicalObjectIdFromCounter() + kObjectIdMaximumValue; -} - -int64_t IdUtil::IsLogicalValueId(int64_t id) { - CHECK(IsValueId(id)); - return ((id + 1) % kObjectIdMaximumValue) == 0; -} - -int64_t IdUtil::NewPhysicalValueObjectId(int32_t machine_id) { - int64_t val = NewPhysicalObjectIdFromCounter(machine_id); - CHECK_LT(val, kObjectIdMaximumValue); - return val; -} - -int64_t IdUtil::NewPhysicalValueSymbolId(int32_t machine_id) { - return NewPhysicalObjectIdFromCounter(machine_id) + kObjectIdMaximumValue; -} - -bool IdUtil::IsObjectId(int64_t object_id) { return object_id < kObjectIdMaximumValue; } - -bool IdUtil::IsSymbolId(int64_t symbol_id) { return symbol_id > kObjectIdMaximumValue; } - -int64_t IdUtil::GetTypeId(int64_t id) { - if (IsTypeId(id)) { return id; } - return -id; -} - -bool IdUtil::IsTypeId(int64_t id) { return id < 0; } - -int64_t IdUtil::GetValueId(int64_t id) { - if (IsValueId(id)) { return id; } - return -id; -} - -bool IdUtil::IsValueId(int64_t id) { return id > 0; } - -} // namespace vm -} // namespace oneflow diff --git a/oneflow/core/vm/id_util.h b/oneflow/core/vm/id_util.h deleted file mode 100644 index ccd515ecde9..00000000000 --- a/oneflow/core/vm/id_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_VM_LOGICAL_OBJECT_ID_H_ -#define ONEFLOW_CORE_VM_LOGICAL_OBJECT_ID_H_ - -#include -#include "oneflow/core/intrusive/flat_msg.h" - -namespace oneflow { -namespace vm { - -using ObjectId = int64_t; - -struct IdUtil final { - // usually [-4096, 4096] - static int64_t IsErrorId(int64_t id); - - static int64_t IsLogicalId(int64_t id) { return IsLogicalValueId(id); } - static int64_t NewLogicalObjectId() { return NewLogicalValueObjectId(); } - static int64_t NewLogicalSymbolId() { return NewLogicalValueSymbolId(); } - static int64_t NewPhysicalObjectId(int32_t machine_id) { - return NewPhysicalValueObjectId(machine_id); - } - static int64_t NewPhysicalSymbolId(int32_t machine_id) { - return NewPhysicalValueSymbolId(machine_id); - } - - static int64_t IsLogicalValueId(int64_t id); - static int64_t NewLogicalValueObjectId(); - static int64_t NewLogicalValueSymbolId(); - static int64_t NewPhysicalValueObjectId(int32_t machine_id); - static int64_t NewPhysicalValueSymbolId(int32_t machine_id); - - // type object id or value object id - static bool IsObjectId(int64_t object_id); - // type symbol id or value symbol id - static bool IsSymbolId(int64_t symbol_id); - - // type object id or type symbol id - static int64_t GetTypeId(int64_t id); - static bool IsTypeId(int64_t id); - - // value object id or value symbol id - static int64_t GetValueId(int64_t id); - static bool IsValueId(int64_t id); -}; - -} // namespace vm -} // namespace oneflow - -#endif // ONEFLOW_CORE_VM_LOGICAL_OBJECT_ID_H_ diff --git a/oneflow/core/vm/instr_type_id.h b/oneflow/core/vm/instr_type_id.h deleted file mode 100644 index 4e41b4f8462..00000000000 --- a/oneflow/core/vm/instr_type_id.h +++ /dev/null @@ -1,81 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_VM_INSTRUCTION_ID_H_ -#define ONEFLOW_CORE_VM_INSTRUCTION_ID_H_ - -#include -#include "oneflow/core/intrusive/flat_msg.h" -#include "oneflow/core/common/layout_standardize.h" -#include "oneflow/core/vm/stream_desc.h" - -namespace oneflow { -namespace vm { - -class InstructionType; -class StreamType; - -class InstrTypeId final { - public: - InstrTypeId() { __Init__(); } - InstrTypeId(const InstrTypeId& rhs) { - __Init__(); - CopyFrom(rhs); - } - - ~InstrTypeId() = default; - - void __Init__() { clear(); } - void __Init__(const StreamType* stream_type, const InstructionType* instruction_type) { - __Init__(); - set_stream_type(stream_type); - instruction_type_ = instruction_type; - } - void clear() { - stream_type_ = nullptr; - instruction_type_ = nullptr; - } - void CopyFrom(const InstrTypeId& rhs) { - stream_type_ = &rhs.stream_type(); - instruction_type_ = &rhs.instruction_type(); - } - // Getters - const StreamType& stream_type() const { return *stream_type_; } - const InstructionType& instruction_type() const { return *instruction_type_; } - - // Setters - void set_stream_type(const StreamType* stream_type) { stream_type_ = stream_type; } - - bool operator==(const InstrTypeId& rhs) const { - return stream_type_ == rhs.stream_type_ && instruction_type_ == rhs.instruction_type_; - } - bool operator<(const InstrTypeId& rhs) const { - if (!(stream_type_ == rhs.stream_type_)) { return stream_type_ < rhs.stream_type_; } - if (!(instruction_type_ == rhs.instruction_type_)) { - return instruction_type_ < rhs.instruction_type_; - } - return false; - } - bool operator<=(const InstrTypeId& rhs) const { return *this < rhs || *this == rhs; } - - private: - const InstructionType* instruction_type_; - const StreamType* stream_type_; -}; - -} // namespace vm -} // namespace oneflow - -#endif // ONEFLOW_CORE_VM_INSTRUCTION_ID_H_ diff --git a/oneflow/core/vm/instruction.cpp b/oneflow/core/vm/instruction.cpp index c4c7a93f6a0..300580f78a4 100644 --- a/oneflow/core/vm/instruction.cpp +++ b/oneflow/core/vm/instruction.cpp @@ -19,6 +19,7 @@ limitations under the License. #include "oneflow/core/vm/stream.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/virtual_machine_engine.h" +#include "oneflow/core/framework/stream_get_stream_role_name.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/profiler/profiler.h" @@ -27,66 +28,26 @@ namespace oneflow { namespace vm { std::string InstructionMsg::DebugName() const { - std::string op_type_name = instr_type_id().instruction_type().DebugOpTypeName(*this); - return op_type_name + ":" + instr_type_name(); + std::string instr_name = instruction_type().DebugName(*this); + return instr_name + ":" + GetStreamRoleName::Visit(stream().stream_role()); } -void InstructionMsg::__Init__() { *mut_instr_type_name() = ""; } - -void InstructionMsg::__Init__(const std::string& instr_type_name) { - __Init__(); - mut_instr_type_id()->CopyFrom(LookupInstrTypeId(instr_type_name)); - *mut_instr_type_name() = instr_type_name; -} - -void InstructionMsg::__Init__(VirtualMachineEngine* vm, const std::string& instr_type_name, - const std::shared_ptr& phy_instr_parallel_desc, +void InstructionMsg::__Init__(Stream* stream, const InstructionType* instruction_type, const std::shared_ptr& phy_instr_operand) { - __Init__(); - // There are instructions without concept of ParallelDesc, like LaunchLazyJob, - // ComputeGlobalFrontSeqBarrier. If phy_instr_parallel_desc is empty, Instructions are run on the - // sole stream within the StreamRtDesc. - if (likely(phy_instr_parallel_desc)) { - int device_id = phy_instr_parallel_desc->parallel_id2device_id().at(0); - vm->GetCachedInstrTypeIdAndPhyInstrStream(instr_type_name, device_id, mut_instr_type_id(), - &phy_instr_stream_); - } else { - vm->GetInstrTypeIdAndSoleStream(instr_type_name, mut_instr_type_id(), &phy_instr_stream_); - } - *mut_instr_type_name() = instr_type_name; - phy_instr_parallel_desc_ = phy_instr_parallel_desc; + stream_ = stream; + instruction_type_ = instruction_type; phy_instr_operand_ = phy_instr_operand; } -void InstructionMsg::__Init__(const InstructionMsg& instr_msg) { - __Init__(); - mut_instr_type_id()->CopyFrom(instr_msg.instr_type_id()); - *mut_instr_type_name() = instr_msg.instr_type_name(); - const auto& parallel_desc = instr_msg.phy_instr_parallel_desc(); - if (parallel_desc) { phy_instr_parallel_desc_ = parallel_desc; } - phy_instr_operand_ = instr_msg.phy_instr_operand(); - if (instr_msg.phy_instr_stream() != nullptr) { phy_instr_stream_ = instr_msg.phy_instr_stream(); } -} - -intrusive::shared_ptr InstructionMsg::Clone() const { - return intrusive::make_shared(*this); -} - -void Instruction::Init(InstructionMsg* instr_msg, Stream* stream, - const std::shared_ptr& parallel_desc) { - __Init__(); - reset_instr_msg(instr_msg); - set_stream(stream); - instr_msg->instr_type_id().instruction_type().InitInstructionStatusIf(this); - *mut_parallel_desc() = parallel_desc; +void Instruction::Init(InstructionMsg* instr_msg) { + instr_msg_ = instr_msg; + instr_msg->instruction_type().InitInstructionStatusIf(this); } void Instruction::Delete() { OF_PROFILER_RANGE_GUARD("Instruction::Delete"); - instr_msg().instr_type_id().instruction_type().DeleteInstructionStatusIf(this); - OF_PROFILER_RANGE_PUSH("ClearInstrMsg"); + instr_msg().instruction_type().DeleteInstructionStatusIf(this); clear_instr_msg(); - OF_PROFILER_RANGE_POP(); mut_in_edges()->Clear(); mut_out_edges()->Clear(); } diff --git a/oneflow/core/vm/instruction.h b/oneflow/core/vm/instruction.h index 3b0034d97d7..0323fb36d97 100644 --- a/oneflow/core/vm/instruction.h +++ b/oneflow/core/vm/instruction.h @@ -18,48 +18,33 @@ limitations under the License. #include #include -#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/common/symbol.h" #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/intrusive/object_pool.h" -#include "oneflow/core/vm/stream_desc.h" #include "oneflow/core/vm/vm_object.h" #include "oneflow/core/vm/stream_type.h" -#include "oneflow/core/vm/instr_type_id.h" -#include "oneflow/core/vm/id_util.h" -#include "oneflow/core/vm/instruction.pb.h" #include "oneflow/core/vm/phy_instr_operand.h" namespace oneflow { -namespace vm { -class VirtualMachineEngine; +class Stream; + +namespace vm { class InstructionMsg final : public intrusive::Base { public: - // Getters - const std::string& instr_type_name() const { return instr_type_name_; } - const InstrTypeId& instr_type_id() const { return instr_type_id_; } - const std::shared_ptr& phy_instr_parallel_desc() const { - return phy_instr_parallel_desc_; - } - const std::shared_ptr& phy_instr_operand() const { return phy_instr_operand_; } - Stream* phy_instr_stream() const { return phy_instr_stream_; } - // Setters - std::string* mut_instr_type_name() { return &instr_type_name_; } - InstrTypeId* mut_instr_type_id() { return &instr_type_id_; } - // methods - void __Init__(); - void __Init__(const std::string& instr_type_name); - void __Init__(VirtualMachineEngine* vm, const std::string& instr_type_name, - const std::shared_ptr& phy_instr_parallel_desc, + void __Init__(Stream* stream, const InstructionType* instruction_type, const std::shared_ptr& phy_instr_operand); - void __Init__(const InstructionMsg& instr_msg); - std::string DebugName() const; + // Getters + const Stream& stream() const { return *stream_; } + Stream* mut_stream() { return stream_; } + const InstructionType& instruction_type() const { return *instruction_type_; } + const std::shared_ptr& phy_instr_operand() const { return phy_instr_operand_; } - intrusive::shared_ptr Clone() const; + std::string DebugName() const; intrusive::Ref::RefCntType ref_cnt() const { return intrusive_ref_.ref_cnt(); } @@ -68,21 +53,12 @@ class InstructionMsg final : public intrusive::Base { intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } InstructionMsg() - : intrusive_ref_(), - instr_type_id_(), - instr_type_name_(), - phy_instr_parallel_desc_(), - phy_instr_operand_(), - phy_instr_stream_(), - instr_msg_hook_() {} + : intrusive_ref_(), stream_(), instruction_type_(), phy_instr_operand_(), instr_msg_hook_() {} intrusive::Ref intrusive_ref_; // fields - InstrTypeId instr_type_id_; - // instr_type_name is a necessary reduandant field for method ToProto - std::string instr_type_name_; - std::shared_ptr phy_instr_parallel_desc_; + Stream* stream_; + const InstructionType* instruction_type_; std::shared_ptr phy_instr_operand_; - Stream* phy_instr_stream_; public: // list hooks @@ -158,15 +134,8 @@ class Instruction final : public intrusive::Base { intrusive::List; // Getters - void __Init__() { clear_stream(); } - bool has_stream() const { return stream_ != nullptr; } - const Stream& stream() const { return *stream_; } - const InstructionMsg& instr_msg() const { - if (instr_msg_) { return instr_msg_.Get(); } - static const auto default_val = intrusive::make_shared(); - return default_val.Get(); - } - const std::shared_ptr& parallel_desc() const { return parallel_desc_; } + const Stream& stream() const { return instr_msg_->stream(); } + const InstructionMsg& instr_msg() const { return instr_msg_.Get(); } const InstructionStatusBuffer& status_buffer() const { return status_buffer_.Get(); } const intrusive::ListHook& instruction_hook() const { return instruction_hook_; } const intrusive::ListHook& dispatched_instruction_hook() const { @@ -180,21 +149,17 @@ class Instruction final : public intrusive::Base { const DependenceAccessList& access_list() const { return access_list_; } // Setters - void set_stream(Stream* val) { stream_ = val; } - void clear_stream() { stream_ = nullptr; } - Stream* mut_stream() { return stream_; } + Stream* mut_stream() { return instr_msg_->mut_stream(); } InstructionMsg* mut_instr_msg() { return CHECK_NOTNULL(instr_msg_.Mutable()); } void reset_instr_msg(InstructionMsg* instr_msg) { instr_msg_.Reset(instr_msg); } void clear_instr_msg() { instr_msg_.Reset(); } - std::shared_ptr* mut_parallel_desc() { return ¶llel_desc_; } InstructionStatusBuffer* mut_status_buffer() { return status_buffer_.Mutable(); } InEdgeList* mut_in_edges() { return &in_edges_; } OutEdgeList* mut_out_edges() { return &out_edges_; } DependenceAccessList* mut_access_list() { return &access_list_; } // methods - void Init(InstructionMsg* instr_msg, Stream* stream, - const std::shared_ptr& parallel_desc); + void Init(InstructionMsg* instr_msg); void Delete(); bool Done() const; const StreamType& stream_type() const; @@ -209,8 +174,6 @@ class Instruction final : public intrusive::Base { : intrusive_ref_(), status_buffer_(), instr_msg_(), - parallel_desc_(), - stream_(), access_list_(), in_edges_(), out_edges_(), @@ -223,8 +186,6 @@ class Instruction final : public intrusive::Base { // fields FlatMsg status_buffer_; intrusive::shared_ptr instr_msg_; - std::shared_ptr parallel_desc_; - Stream* stream_; // lists DependenceAccessList access_list_; InEdgeList in_edges_; diff --git a/oneflow/core/vm/instruction.proto b/oneflow/core/vm/instruction.proto deleted file mode 100644 index 8c3d9a26495..00000000000 --- a/oneflow/core/vm/instruction.proto +++ /dev/null @@ -1,49 +0,0 @@ -syntax = "proto2"; -package oneflow.vm; - -message CurrentGlobalDeviceIdProto {} -message SoleMirroredObjectProto {} -message AllMirroredObjectProto {} - -message OperandProto { - required int64 logical_object_id = 1; - oneof operand_type { - CurrentGlobalDeviceIdProto current_global_device_id = 2; - SoleMirroredObjectProto sole_mirrored_object = 3; - AllMirroredObjectProto all_mirrored_object = 4; - } -} - -message OperandSeparatorProto { } - -message InstructionOperandProto { - oneof type { - // read only object - OperandProto const_operand = 1; - // writeable object - OperandProto mut_operand = 2; - // mut2 writeable object - OperandProto mut2_operand = 3; - OperandProto del_operand = 4; - // read only symbol - OperandProto symbol_operand = 5; - // initializable symbol - OperandProto init_symbol_operand = 6; - - OperandSeparatorProto separator = 7; - double double_operand = 8; - int64 int64_operand = 9; - uint64 uint64_operand = 10; - bool bool_operand = 11; - } -} - -message InstructionProto { - required string instr_type_name = 1; - optional int64 parallel_desc_symbol_id = 2 [default = 0]; - repeated InstructionOperandProto operand = 3; -}; - -message InstructionListProto { - repeated InstructionProto instruction = 1; -} diff --git a/oneflow/core/vm/instruction_type.cpp b/oneflow/core/vm/instruction_type.cpp index d2bb48f4ad8..174459b1f34 100644 --- a/oneflow/core/vm/instruction_type.cpp +++ b/oneflow/core/vm/instruction_type.cpp @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/vm/instr_type_id.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/common/util.h" @@ -21,15 +20,6 @@ limitations under the License. namespace oneflow { namespace vm { -namespace { - -HashMap* InstrTypeId4InstructionName() { - static HashMap map; - return ↦ -} - -} // namespace - void InstructionType::InitInstructionStatus(Instruction* instruction) const { instruction->stream_type().InitInstructionStatus(instruction->stream(), instruction->mut_status_buffer()); @@ -40,23 +30,5 @@ void InstructionType::DeleteInstructionStatus(Instruction* instruction) const { instruction->mut_status_buffer()); } -const InstrTypeId& LookupInstrTypeId(const std::string& name) { - const auto& map = *InstrTypeId4InstructionName(); - const auto& iter = map.find(name); - CHECK(iter != map.end()) << "instruction type name: " << name; - return iter->second; -} - -void ForEachInstrTypeId(std::function DoEach) { - for (const auto& pair : *InstrTypeId4InstructionName()) { DoEach(pair.second); } -} - -void RegisterInstrTypeId(const std::string& instruction_name, const StreamType* stream_type, - const InstructionType* instruction_type) { - InstrTypeId instr_type_id; - instr_type_id.__Init__(stream_type, instruction_type); - CHECK(InstrTypeId4InstructionName()->emplace(instruction_name, instr_type_id).second); -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/vm/instruction_type.h b/oneflow/core/vm/instruction_type.h index 005c57751e8..ac1f3244dee 100644 --- a/oneflow/core/vm/instruction_type.h +++ b/oneflow/core/vm/instruction_type.h @@ -36,8 +36,7 @@ class InstructionType { public: virtual ~InstructionType() = default; - bool IsSequential() const { return IsFrontSequential(); } - virtual bool IsFrontSequential() const { return false; } + virtual bool IsBarrier() const { return false; } virtual InstructionFuseType fuse_type() const { return kDisableInstructionFuse; } virtual void Compute(Instruction* instruction) const = 0; @@ -49,7 +48,7 @@ class InstructionType { DeleteInstructionStatus(instruction); } - virtual std::string DebugOpTypeName(const InstructionMsg&) const { return ""; } + virtual std::string DebugName(const InstructionMsg&) const = 0; protected: InstructionType() = default; @@ -59,28 +58,6 @@ class InstructionType { virtual void DeleteInstructionStatus(Instruction* instruction) const; }; -class InstrTypeId; -const InstrTypeId& LookupInstrTypeId(const std::string& instr_type_name); -void ForEachInstrTypeId(std::function DoEach); -void RegisterInstrTypeId(const std::string& instr_type_name, const StreamType* stream_type, - const InstructionType* instruction_type); - -template -const InstructionType* StaticGlobalInstructionType() { - static const InstructionType* instruction_type = new T(); - return instruction_type; -} - -template -void RegisterInstrTypeId(const std::string& instr_type_name, const StreamType* stream_type) { - RegisterInstrTypeId(instr_type_name, stream_type, StaticGlobalInstructionType()); -} - -template -void RegisterInstructionType(const std::string& instr_type_name) { - RegisterInstrTypeId(instr_type_name, StaticGlobalStreamType()); -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/eager/lazy_job_device_context.h b/oneflow/core/vm/lazy_job_device_context.h similarity index 93% rename from oneflow/core/eager/lazy_job_device_context.h rename to oneflow/core/vm/lazy_job_device_context.h index d0e56590c5f..593c4f8d335 100644 --- a/oneflow/core/eager/lazy_job_device_context.h +++ b/oneflow/core/vm/lazy_job_device_context.h @@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_EAGER_LAZY_JOB_DEVICE_CONTEXT_H_ -#define ONEFLOW_CORE_EAGER_LAZY_JOB_DEVICE_CONTEXT_H_ +#ifndef ONEFLOW_CORE_VM_LAZY_JOB_DEVICE_CONTEXT_H_ +#define ONEFLOW_CORE_VM_LAZY_JOB_DEVICE_CONTEXT_H_ #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/common/util.h" @@ -93,4 +93,4 @@ class LazyJobDeviceCtx final : public DeviceCtx { } // namespace vm } // namespace oneflow -#endif // ONEFLOW_CORE_EAGER_LAZY_JOB_DEVICE_CONTEXT_H_ +#endif // ONEFLOW_CORE_VM_LAZY_JOB_DEVICE_CONTEXT_H_ diff --git a/oneflow/core/eager/lazy_job_stream_type.cpp b/oneflow/core/vm/lazy_job_stream_type.cpp similarity index 75% rename from oneflow/core/eager/lazy_job_stream_type.cpp rename to oneflow/core/vm/lazy_job_stream_type.cpp index b34a2f03924..2d5720dd83c 100644 --- a/oneflow/core/eager/lazy_job_stream_type.cpp +++ b/oneflow/core/vm/lazy_job_stream_type.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/eager/lazy_job_stream_type.h" +#include "oneflow/core/vm/lazy_job_stream_type.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/thread_ctx.h" -#include "oneflow/core/eager/lazy_job_device_context.h" +#include "oneflow/core/vm/lazy_job_device_context.h" #include "oneflow/core/vm/naive_instruction_status_querier.h" #include "oneflow/core/common/util.h" @@ -48,19 +48,7 @@ bool LazyJobStreamType::QueryInstructionStatusDone( } void LazyJobStreamType::Compute(Instruction* instruction) const { - { - const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id(); - instr_type_id.instruction_type().Compute(instruction); - } -} - -intrusive::shared_ptr LazyJobStreamType::MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const { - auto ret = intrusive::make_shared(); - ret->set_stream_type(StaticGlobalStreamType()); - ret->set_num_streams_per_machine(1); - ret->set_num_streams_per_thread(1); - return ret; + instruction->instr_msg().instruction_type().Compute(instruction); } } // namespace vm diff --git a/oneflow/core/eager/lazy_job_stream_type.h b/oneflow/core/vm/lazy_job_stream_type.h similarity index 81% rename from oneflow/core/eager/lazy_job_stream_type.h rename to oneflow/core/vm/lazy_job_stream_type.h index 10cad9c2eaf..dd2196c7347 100644 --- a/oneflow/core/eager/lazy_job_stream_type.h +++ b/oneflow/core/vm/lazy_job_stream_type.h @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_EAGER_LAZY_JOB_STREAM_TYPE_H_ -#define ONEFLOW_CORE_EAGER_LAZY_JOB_STREAM_TYPE_H_ +#ifndef ONEFLOW_CORE_VM_LAZY_JOB_STREAM_TYPE_H_ +#define ONEFLOW_CORE_VM_LAZY_JOB_STREAM_TYPE_H_ #include "oneflow/core/intrusive/flat_msg_view.h" #include "oneflow/core/vm/stream_type.h" @@ -31,8 +31,6 @@ class LazyJobStreamType final : public StreamType { LazyJobStreamType() = default; virtual ~LazyJobStreamType() = default; - const char* stream_tag() const override { return "lazy_job"; } - void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const override; void InitInstructionStatus(const Stream& stream, @@ -44,11 +42,9 @@ class LazyJobStreamType final : public StreamType { void Compute(Instruction* instruction) const override; bool OnSchedulerThread() const override { return false; } bool SupportingTransportInstructions() const override { return false; } - intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const override; }; } // namespace vm } // namespace oneflow -#endif // ONEFLOW_CORE_EAGER_LAZY_JOB_STREAM_TYPE_H_ +#endif // ONEFLOW_CORE_VM_LAZY_JOB_STREAM_TYPE_H_ diff --git a/oneflow/core/vm/runtime_instr_type_id.h b/oneflow/core/vm/runtime_instr_type_id.h deleted file mode 100644 index d146b853893..00000000000 --- a/oneflow/core/vm/runtime_instr_type_id.h +++ /dev/null @@ -1,52 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_VM_RUNTIME_INSTR_TYPE_ID_H_ -#define ONEFLOW_CORE_VM_RUNTIME_INSTR_TYPE_ID_H_ - -#include "oneflow/core/vm/instr_type_id.h" -#include "oneflow/core/vm/stream_runtime_desc.h" - -namespace oneflow { -namespace vm { - -class RtInstrTypeId final { - public: - RtInstrTypeId(const RtInstrTypeId&) = default; - RtInstrTypeId(RtInstrTypeId&&) = default; - ~RtInstrTypeId() = default; - - RtInstrTypeId(const InstrTypeId& instr_type_id, StreamRtDesc* stream_rt_desc) - : instr_type_id_(instr_type_id), stream_rt_desc_(stream_rt_desc) { - if (stream_rt_desc->stream_type().IsControlStreamType()) { - get_stream_ = &StreamRtDesc::GetSoleStream; - } else { - get_stream_ = &StreamRtDesc::GetDeviceStream; - } - } - - const InstrTypeId& instr_type_id() const { return instr_type_id_; } - Stream* GetStream(int device_id) const { return (stream_rt_desc_->*get_stream_)(device_id); } - - private: - const InstrTypeId instr_type_id_; - StreamRtDesc* stream_rt_desc_; - Stream* (StreamRtDesc::*get_stream_)(int device_id) const; -}; - -} // namespace vm -} // namespace oneflow - -#endif // ONEFLOW_CORE_VM_RUNTIME_INSTR_TYPE_ID_H_ diff --git a/oneflow/core/vm/sequential_instruction_type.cpp b/oneflow/core/vm/sequential_instruction_type.cpp deleted file mode 100644 index dca5a7473e0..00000000000 --- a/oneflow/core/vm/sequential_instruction_type.cpp +++ /dev/null @@ -1,105 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/common/util.h" -#include "oneflow/core/intrusive/flat_msg_view.h" -#include "oneflow/core/rpc/include/base.h" -#include "oneflow/core/vm/control_stream_type.h" -#include "oneflow/core/vm/instruction_type.h" -#include "oneflow/core/vm/instruction.h" -#include "oneflow/core/vm/virtual_machine_engine.h" -#include "oneflow/core/vm/barrier_phy_instr_operand.h" -#include "oneflow/core/control/global_process_ctx.h" - -namespace oneflow { -namespace vm { - -class RankFrontSeqCallbackInstructionType : public InstructionType { - public: - RankFrontSeqCallbackInstructionType() = default; - virtual ~RankFrontSeqCallbackInstructionType() override = default; - - bool IsFrontSequential() const override { return true; } - - protected: -}; - -class ComputeRankFrontSeqCallbackInstructionType final - : public RankFrontSeqCallbackInstructionType { - public: - ComputeRankFrontSeqCallbackInstructionType() = default; - ~ComputeRankFrontSeqCallbackInstructionType() override = default; - - using stream_type = ControlStreamType; - - void Compute(Instruction* instruction) const override { - const auto* operand = instruction->instr_msg().phy_instr_operand().get(); - const auto* barrier_operand = dynamic_cast(operand); - CHECK_NOTNULL(barrier_operand)->callback(); - } - void ComputeInFuseMode(InstructionMsg* instr_msg) const override { - const auto* operand = instr_msg->phy_instr_operand().get(); - const auto* barrier_operand = dynamic_cast(operand); - CHECK_NOTNULL(barrier_operand)->callback(); - } -}; -COMMAND(RegisterInstructionType( - "ComputeRankFrontSeqCallback")); - -class CtrlComputeRankFrontSeqCallbackInstructionType final - : public RankFrontSeqCallbackInstructionType { - public: - CtrlComputeRankFrontSeqCallbackInstructionType() = default; - ~CtrlComputeRankFrontSeqCallbackInstructionType() override = default; - - using stream_type = ControlStreamType; - - void Compute(Instruction* instruction) const override { - const auto* operand = instruction->instr_msg().phy_instr_operand().get(); - const auto* barrier_operand = dynamic_cast(operand); - CHECK_NOTNULL(barrier_operand)->callback(); - } -}; -COMMAND(RegisterInstructionType( - "CtrlComputeRankFrontSeqCallback")); - -class GlobalFrontSeqBarrierInstructionType : public InstructionType { - public: - GlobalFrontSeqBarrierInstructionType() = default; - virtual ~GlobalFrontSeqBarrierInstructionType() override = default; - - using stream_type = ControlStreamType; - - virtual bool IsFrontSequential() const override { return true; } -}; - -class ComputeGlobalFrontSeqBarrierInstructionType final - : public GlobalFrontSeqBarrierInstructionType { - public: - ComputeGlobalFrontSeqBarrierInstructionType() = default; - ~ComputeGlobalFrontSeqBarrierInstructionType() override = default; - - void Compute(Instruction* instruction) const override { - OF_ENV_BARRIER(); - const auto* operand = instruction->instr_msg().phy_instr_operand().get(); - const auto* barrier_operand = dynamic_cast(operand); - CHECK_NOTNULL(barrier_operand)->callback(); - } -}; -COMMAND(RegisterInstructionType( - "ComputeGlobalFrontSeqBarrier")); - -} // namespace vm -} // namespace oneflow diff --git a/oneflow/core/vm/stream.cpp b/oneflow/core/vm/stream.cpp index 50f3ea09262..d2c7d2f055c 100644 --- a/oneflow/core/vm/stream.cpp +++ b/oneflow/core/vm/stream.cpp @@ -17,40 +17,37 @@ limitations under the License. #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/cpp_attribute.h" +#include "oneflow/core/framework/device.h" +#include "oneflow/core/vm/stream_get_stream_type.h" namespace oneflow { namespace vm { -void Stream::__Init__() { clear_thread_ctx(); } - -void Stream::__Init__(ThreadCtx* thread_ctx, const StreamId& stream_id, - const int64_t max_device_num_per_machine) { - __Init__(); +void Stream::__Init__( + ThreadCtx* thread_ctx, Symbol device, StreamRole stream_role, + const intrusive::shared_ptr& schedule_local_dep_object, + const Optional>& transport_local_dep_object) { set_thread_ctx(thread_ctx); - mut_stream_id()->CopyFrom(stream_id); - // InitDeviceCtx may use max_device_num_per_machine, - // so max_device_num_per_machine must be set before InitDeviceCtx - set_max_device_num_per_machine(max_device_num_per_machine); - stream_type().InitDeviceCtx(mut_device_ctx(), this); + device_ = device; + stream_role_ = stream_role; + stream_type_ = CHECK_JUST(GetStreamType::Visit(stream_role, device->enum_type())); + stream_type_->InitDeviceCtx(mut_device_ctx(), this); + schedule_local_dep_object_ = schedule_local_dep_object; + transport_local_dep_object_ = transport_local_dep_object; } -int64_t Stream::machine_id() const { return global_device_id() / max_device_num_per_machine(); } - -int64_t Stream::device_id() const { return global_device_id() % max_device_num_per_machine(); } +int64_t Stream::device_id() const { return device_->device_id(); } -const StreamType& Stream::stream_type() const { - return thread_ctx().stream_rt_desc().stream_type(); -} +const StreamType& Stream::stream_type() const { return *stream_type_; } -intrusive::shared_ptr Stream::NewInstruction( - InstructionMsg* instr_msg, const std::shared_ptr& parallel_desc) { +intrusive::shared_ptr Stream::NewInstruction(InstructionMsg* instr_msg) { intrusive::shared_ptr instruction; if (unlikely(free_instruction_list().empty())) { instruction = intrusive::make_shared(); } else { instruction = mut_free_instruction_list()->PopFront(); } - instruction->Init(instr_msg, this, parallel_desc); + instruction->Init(instr_msg); return instruction; } diff --git a/oneflow/core/vm/stream.h b/oneflow/core/vm/stream.h index 3e1936f5b2d..d668a7d9463 100644 --- a/oneflow/core/vm/stream.h +++ b/oneflow/core/vm/stream.h @@ -16,14 +16,21 @@ limitations under the License. #ifndef ONEFLOW_CORE_VM_STREAM_H_ #define ONEFLOW_CORE_VM_STREAM_H_ -#include "oneflow/core/vm/stream_desc.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/device/device_context.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/optional.h" +#include "oneflow/core/common/stream_role.h" namespace oneflow { + +class Device; + namespace vm { class ThreadCtx; +class StreamType; +class MirroredObject; class Stream final : public intrusive::Base { public: @@ -32,7 +39,6 @@ class Stream final : public intrusive::Base { intrusive::List; // Getters - int64_t max_device_num_per_machine() const { return max_device_num_per_machine_; } const ThreadCtx& thread_ctx() const { return *thread_ctx_; } bool has_thread_ctx() const { return thread_ctx_ != nullptr; } const std::unique_ptr& device_ctx() const { return device_ctx_; } @@ -44,10 +50,8 @@ class Stream final : public intrusive::Base { const DispatchedInstructionList& running_instruction_list() const { return running_instruction_list_; } - const StreamId& stream_id() const { return stream_id_.key(); } // Setters - void set_max_device_num_per_machine(int64_t val) { max_device_num_per_machine_ = val; } ThreadCtx* mut_thread_ctx() { return thread_ctx_; } void set_thread_ctx(ThreadCtx* val) { thread_ctx_ = val; } void clear_thread_ctx() { thread_ctx_ = nullptr; } @@ -55,20 +59,26 @@ class Stream final : public intrusive::Base { DispatchedInstructionList* mut_free_instruction_list() { return &free_instruction_list_; } DispatchedInstructionList* mut_zombie_instruction_list() { return &zombie_instruction_list_; } DispatchedInstructionList* mut_running_instruction_list() { return &running_instruction_list_; } - StreamId* mut_stream_id() { return stream_id_.mut_key(); } // methods - void __Init__(); - void __Init__(ThreadCtx* thread_ctx, const StreamId& stream_id, - const int64_t max_device_num_per_machine); - intrusive::shared_ptr NewInstruction( - InstructionMsg* instr_msg, const std::shared_ptr& parallel_desc); + void __Init__(ThreadCtx* thread_ctx, Symbol device, StreamRole stream_role, + const intrusive::shared_ptr& schedule_local_dep_object, + const Optional>& transport_local_dep_object); + intrusive::shared_ptr NewInstruction(InstructionMsg* instr_msg); void DeleteInstruction(intrusive::shared_ptr&&); - int64_t global_device_id() const { return stream_id().global_device_id(); } - int64_t machine_id() const; int64_t device_id() const; + Symbol device() const { return device_; } + StreamRole stream_role() const { return stream_role_; } const StreamType& stream_type() const; + const intrusive::shared_ptr& schedule_local_dep_object() const { + return schedule_local_dep_object_; + } + + const Optional>& transport_local_dep_object() const { + return transport_local_dep_object_; + } + private: void MoveToFreeList(intrusive::shared_ptr&& instruction); void MoveFromZombieListToFreeList(); @@ -79,27 +89,31 @@ class Stream final : public intrusive::Base { Stream() : intrusive_ref_(), thread_ctx_(), + device_(), + stream_role_(StreamRole::kInvalid), + stream_type_(), device_ctx_(), - max_device_num_per_machine_(), free_instruction_list_(), zombie_instruction_list_(), running_instruction_list_(), - stream_id_(), active_stream_hook_(), thread_ctx_stream_hook_() {} intrusive::Ref intrusive_ref_; // fields ThreadCtx* thread_ctx_; + Symbol device_; + StreamRole stream_role_; + const StreamType* stream_type_; std::unique_ptr device_ctx_; - int64_t max_device_num_per_machine_; // lists DispatchedInstructionList free_instruction_list_; DispatchedInstructionList zombie_instruction_list_; DispatchedInstructionList running_instruction_list_; + intrusive::shared_ptr schedule_local_dep_object_; + Optional> transport_local_dep_object_; + public: - // skiplist hooks - intrusive::SkipListHook stream_id_; // list hooks intrusive::ListHook active_stream_hook_; intrusive::ListHook thread_ctx_stream_hook_; diff --git a/oneflow/core/vm/stream_desc.cpp b/oneflow/core/vm/stream_desc.cpp deleted file mode 100644 index d026186d935..00000000000 --- a/oneflow/core/vm/stream_desc.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/vm/stream_desc.h" - -namespace oneflow { -namespace vm { - -void StreamDesc::__Init__(const StreamType* stream_type, int32_t num_streams_per_machine, - int32_t num_streams_per_thread) { - set_stream_type(stream_type); - set_num_streams_per_machine(num_streams_per_machine); - set_num_streams_per_thread(num_streams_per_thread); -} - -int32_t StreamDesc::num_threads() const { - int32_t num_devices = num_streams_per_machine(); - if (num_devices == 0) { return 0; } - CHECK_EQ(num_devices % num_streams_per_thread(), 0); - return num_devices / num_streams_per_thread(); -} - -} // namespace vm -} // namespace oneflow diff --git a/oneflow/core/vm/stream_desc.h b/oneflow/core/vm/stream_desc.h deleted file mode 100644 index a996bc0dd03..00000000000 --- a/oneflow/core/vm/stream_desc.h +++ /dev/null @@ -1,99 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_VM_VPU_DESC__H_ -#define ONEFLOW_CORE_VM_VPU_DESC__H_ - -#include -#include -#include "oneflow/core/intrusive/flat_msg.h" -#include "oneflow/core/intrusive/intrusive.h" -#include "oneflow/core/vm/id_util.h" - -namespace oneflow { -namespace vm { - -class StreamType; - -class StreamId final { - public: - using self_type = StreamId; - void __Init__() {} - void __Init__(const StreamType* stream_type, int64_t global_device_id) { - stream_type_ = stream_type; - global_device_id_ = global_device_id; - } - - void CopyFrom(const StreamId& rhs) { __Init__(rhs.stream_type_, rhs.global_device_id_); } - - const StreamType& stream_type() const { return *stream_type_; } - int64_t global_device_id() const { return global_device_id_; } - - bool operator==(const StreamId& rhs) const { - return stream_type_ == rhs.stream_type_ && global_device_id_ == rhs.global_device_id_; - } - - bool operator<(const StreamId& rhs) const { - if (!(stream_type_ == rhs.stream_type_)) { return stream_type_ < rhs.stream_type_; } - return global_device_id_ < rhs.global_device_id_; - } - bool operator<=(const StreamId& rhs) const { return *this < rhs || *this == rhs; } - - private: - const StreamType* stream_type_; - int64_t global_device_id_; -}; - -class StreamDesc final : public intrusive::Base { - public: - // Getters - int32_t num_streams_per_machine() const { return num_streams_per_machine_; } - int32_t num_streams_per_thread() const { return num_streams_per_thread_; } - const StreamType& stream_type() const { return *stream_type_key_.key(); } - // Setters - void set_num_streams_per_machine(int32_t val) { num_streams_per_machine_ = val; } - void set_num_streams_per_thread(int32_t val) { num_streams_per_thread_ = val; } - void set_stream_type(const StreamType* stream_type) { *stream_type_key_.mut_key() = stream_type; } - - // methods - void __Init__() {} - void __Init__(const StreamType* stream_type, int32_t num_streams_per_machine, - int32_t num_streams_per_thread); - int32_t num_threads() const; - int32_t parallel_num() const { return num_streams_per_machine(); } - - private: - friend class intrusive::Ref; - intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } - - StreamDesc() - : intrusive_ref_(), - num_streams_per_machine_(), - num_streams_per_thread_(), - stream_type_key_() {} - intrusive::Ref intrusive_ref_; - // fields - int32_t num_streams_per_machine_; - int32_t num_streams_per_thread_; - - public: - // skiplist hooks - intrusive::SkipListHook stream_type_key_; -}; - -} // namespace vm -} // namespace oneflow - -#endif // ONEFLOW_CORE_VM_VPU_DESC__H_ diff --git a/oneflow/core/vm/stream_get_stream_type.h b/oneflow/core/vm/stream_get_stream_type.h new file mode 100644 index 00000000000..2eb1d6ca879 --- /dev/null +++ b/oneflow/core/vm/stream_get_stream_type.h @@ -0,0 +1,108 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_VM_STREAM_GET_STREAM_TYPE_H_ +#define ONEFLOW_CORE_VM_STREAM_GET_STREAM_TYPE_H_ + +#include "oneflow/core/common/stream_role.h" +#include "oneflow/core/common/singleton_ptr.h" +#include "oneflow/core/vm/event_recorded_cuda_stream_type.h" +#include "oneflow/core/vm/control_stream_type.h" +#include "oneflow/core/vm/cpu_stream_type.h" +#include "oneflow/core/vm/critical_section_stream_type.h" +#include "oneflow/core/vm/cuda_copy_d2h_stream_type.h" +#include "oneflow/core/vm/cuda_copy_h2d_stream_type.h" +#include "oneflow/core/vm/cuda_stream_type.h" +#include "oneflow/core/vm/lazy_job_stream_type.h" +#include "oneflow/core/vm/stream_get_stream_type.h" + +namespace oneflow { + +struct GetStreamType final : public StreamRoleVisitor { + static Maybe VisitCompute(DeviceType device_type) { + if (device_type == DeviceType::kCPU) { + return SingletonPtr(); + } else if (device_type == DeviceType::kCUDA) { +#ifdef WITH_CUDA + return SingletonPtr(); +#else + UNIMPLEMENTED_THEN_RETURN(); +#endif + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } + static Maybe VisitHost2Device(DeviceType device_type) { + if (device_type == DeviceType::kCUDA) { +#ifdef WITH_CUDA + return SingletonPtr(); +#else + UNIMPLEMENTED_THEN_RETURN(); +#endif + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } + static Maybe VisitDevice2Host(DeviceType device_type) { + if (device_type == DeviceType::kCUDA) { +#ifdef WITH_CUDA + return SingletonPtr(); +#else + UNIMPLEMENTED_THEN_RETURN(); +#endif + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } + static Maybe VisitSyncedLaunchedCommNet(DeviceType device_type) { + if (device_type == DeviceType::kCPU) { + return SingletonPtr(); + } else if (device_type == DeviceType::kCUDA) { +#ifdef WITH_CUDA + return SingletonPtr(); +#else + UNIMPLEMENTED_THEN_RETURN(); +#endif + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } + static Maybe VisitAsyncedLaunchedCommNet(DeviceType device_type) { + if (device_type == DeviceType::kCPU) { + return SingletonPtr(); + } else if (device_type == DeviceType::kCUDA) { +#ifdef WITH_CUDA + return SingletonPtr(); +#else + UNIMPLEMENTED_THEN_RETURN(); +#endif + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } + static Maybe VisitBarrier(DeviceType device_type) { + return SingletonPtr(); + } + static Maybe VisitCriticalSection(DeviceType device_type) { + return SingletonPtr(); + } + static Maybe VisitLazyJobLauncher(DeviceType device_type) { + return SingletonPtr(); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_VM_STREAM_GET_STREAM_TYPE_H_ diff --git a/oneflow/core/vm/stream_runtime_desc.h b/oneflow/core/vm/stream_runtime_desc.h deleted file mode 100644 index 6e7aa400c55..00000000000 --- a/oneflow/core/vm/stream_runtime_desc.h +++ /dev/null @@ -1,85 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_VM_STREAM_RUNTIME_DESC__H_ -#define ONEFLOW_CORE_VM_STREAM_RUNTIME_DESC__H_ - -#include "oneflow/core/vm/stream_desc.h" -#include "oneflow/core/vm/stream.h" - -namespace oneflow { -namespace vm { - -class StreamType; -class StreamDesc; - -// Rt is short for Runtime -class StreamRtDesc final : public intrusive::Base { - public: - // Getters - const StreamDesc& stream_desc() const { - if (stream_desc_) { return stream_desc_.Get(); } - static const auto default_val = intrusive::make_shared(); - return default_val.Get(); - } - const StreamType& stream_type() const { return *stream_type_key_.key(); } - const std::vector>& device_id2stream() const { - return device_id2stream_; - } - - // The value of `device_id` is ignored. - Stream* GetSoleStream(int device_id) const { return GetSoleStream(); } - Stream* GetSoleStream() const { - CHECK_EQ(device_id2stream().size(), 1); - return device_id2stream().at(0).get(); - } - - Stream* GetDeviceStream(int device_id) const { return device_id2stream().at(device_id).get(); } - - // Setters - StreamDesc* mut_stream_desc() { - if (!stream_desc_) { stream_desc_ = intrusive::make_shared(); } - return stream_desc_.Mutable(); - } - void reset_stream_desc(StreamDesc* stream_desc) { stream_desc_.Reset(stream_desc); } - void set_stream_type(const StreamType* stream_type) { *stream_type_key_.mut_key() = stream_type; } - void add_stream(intrusive::shared_ptr stream) { - CHECK_EQ(stream->device_id(), device_id2stream_.size()); - device_id2stream_.emplace_back(stream); - } - - // methods - void __Init__(StreamDesc* stream_desc); - - private: - friend class intrusive::Ref; - intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } - - StreamRtDesc() : intrusive_ref_(), stream_desc_(), device_id2stream_(), stream_type_key_() {} - intrusive::Ref intrusive_ref_; - // fields - intrusive::shared_ptr stream_desc_; - // containers - std::vector> device_id2stream_; - - public: - // skiplist hooks - intrusive::SkipListHook stream_type_key_; -}; - -} // namespace vm -} // namespace oneflow - -#endif // ONEFLOW_CORE_VM_STREAM_RUNTIME_DESC__H_ diff --git a/oneflow/core/vm/stream_type.h b/oneflow/core/vm/stream_type.h index 8fee7b6054d..0a8868dddc4 100644 --- a/oneflow/core/vm/stream_type.h +++ b/oneflow/core/vm/stream_type.h @@ -19,8 +19,6 @@ limitations under the License. #include #include #include -#include "oneflow/core/vm/stream_desc.h" -#include "oneflow/core/vm/instr_type_id.h" #include "oneflow/core/device/device_context.h" #include "oneflow/core/job/resource.pb.h" @@ -40,8 +38,6 @@ class StreamType { void Run(Instruction* instruction) const { Compute(instruction); } - virtual const char* stream_tag() const = 0; - virtual void InitDeviceCtx(std::unique_ptr* device_ctx, Stream* stream) const = 0; virtual void InitInstructionStatus(const Stream& stream, @@ -52,9 +48,6 @@ class StreamType { const InstructionStatusBuffer& status_buffer) const = 0; virtual void Compute(Instruction* instruction) const = 0; - virtual intrusive::shared_ptr MakeStreamDesc(const Resource& resource, - int64_t this_machine_id) const = 0; - virtual bool OnSchedulerThread() const = 0; virtual bool SupportingTransportInstructions() const = 0; virtual bool IsControlStreamType() const { return false; } diff --git a/oneflow/core/vm/thread_ctx.cpp b/oneflow/core/vm/thread_ctx.cpp index c347fa1d9ed..f91e52867b3 100644 --- a/oneflow/core/vm/thread_ctx.cpp +++ b/oneflow/core/vm/thread_ctx.cpp @@ -20,12 +20,12 @@ namespace oneflow { namespace vm { size_t ThreadCtx::TryReceiveAndRun() { - const StreamType& stream_type = stream_rt_desc().stream_type(); intrusive::List tmp_list; mut_pending_instruction_list()->MoveTo(&tmp_list); size_t size = tmp_list.size(); INTRUSIVE_FOR_EACH(instruction, &tmp_list) { tmp_list.Erase(instruction.Mutable()); + const StreamType& stream_type = instruction->stream().stream_type(); stream_type.Run(instruction.Mutable()); } return size; diff --git a/oneflow/core/vm/thread_ctx.h b/oneflow/core/vm/thread_ctx.h index 150b09f29fc..31d64d8aae8 100644 --- a/oneflow/core/vm/thread_ctx.h +++ b/oneflow/core/vm/thread_ctx.h @@ -21,41 +21,28 @@ limitations under the License. #include "oneflow/core/intrusive/mutexed_list.h" #include "oneflow/core/common/notifier.h" #include "oneflow/core/vm/stream.h" -#include "oneflow/core/vm/stream_runtime_desc.h" namespace oneflow { namespace vm { using PendingInstructionMutexedList = intrusive::MutexedList; -using PendingInstructionList = - intrusive::List; class ThreadCtx final : public intrusive::Base { public: - void __Init__() { clear_stream_rt_desc(); } - // types using StreamList = intrusive::List; // Getters - bool has_stream_rt_desc() const { return stream_rt_desc_ != nullptr; } - const StreamRtDesc& stream_rt_desc() const { return *stream_rt_desc_; } const StreamList& stream_list() const { return stream_list_; } // Setters - void set_stream_rt_desc(const StreamRtDesc* val) { stream_rt_desc_ = val; } - void clear_stream_rt_desc() { stream_rt_desc_ = nullptr; } StreamList* mut_stream_list() { return &stream_list_; } PendingInstructionMutexedList* mut_pending_instruction_list() { return &pending_instruction_list_; } // methods - void __Init__(const StreamRtDesc& stream_rt_desc) { - __Init__(); - set_stream_rt_desc(&stream_rt_desc); - } size_t TryReceiveAndRun(); Notifier* mut_notifier() { return ¬ifier_; } @@ -66,14 +53,12 @@ class ThreadCtx final : public intrusive::Base { ThreadCtx() : intrusive_ref_(), - stream_rt_desc_(), stream_list_(), pending_instruction_mutex_(), pending_instruction_list_(&pending_instruction_mutex_), + notifier_(), thread_ctx_hook_() {} intrusive::Ref intrusive_ref_; - // fields - const StreamRtDesc* stream_rt_desc_; // lists StreamList stream_list_; std::mutex pending_instruction_mutex_; diff --git a/oneflow/core/vm/virtual_machine.cpp b/oneflow/core/vm/virtual_machine.cpp index 6527f8c92b2..fb712e6f255 100644 --- a/oneflow/core/vm/virtual_machine.cpp +++ b/oneflow/core/vm/virtual_machine.cpp @@ -18,18 +18,27 @@ limitations under the License. #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/vm/barrier_phy_instr_operand.h" +#include "oneflow/core/vm/barrier_instruction_type.h" +#include "oneflow/core/vm/barrier_phy_instr_operand.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/cpp_attribute.h" +#include "oneflow/core/common/singleton_ptr.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/foreign_lock_helper.h" #include "oneflow/core/thread/thread_consistent_id.h" #include "oneflow/core/framework/transport_token.h" +#include "oneflow/core/framework/to_string.h" +#include "oneflow/core/framework/stream_on_independent_thread.h" +#include "oneflow/core/framework/stream_is_comm_net_stream.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/platform/include/pthread_fork.h" #include "oneflow/core/common/env_var/env_var.h" +#include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/stream.h" +#include "oneflow/core/framework/stream_mgr.h" namespace oneflow { @@ -42,11 +51,9 @@ int MicrosecondsFrom(const T& start) { .count(); } -Maybe ForEachThreadCtx(vm::VirtualMachineEngine* vm, +Maybe ForEachThreadCtx(vm::VirtualMachineEngine* engine, const std::function(vm::ThreadCtx*)>& DoEach) { - INTRUSIVE_UNSAFE_FOR_EACH_PTR(thread_ctx, vm->mut_thread_ctx_list()) { - const auto& stream_type = thread_ctx->stream_rt_desc().stream_type(); - if (stream_type.OnSchedulerThread()) { continue; } + INTRUSIVE_UNSAFE_FOR_EACH_PTR(thread_ctx, engine->mut_thread_ctx_list()) { JUST(DoEach(thread_ctx)); } return Maybe::Ok(); @@ -59,45 +66,6 @@ void GetSchedulerThreadInitializer(std::function* Initializer) { }; } -std::type_index GetStreamTypeIndex(const vm::ThreadCtx* thread_ctx) { - const auto& stream_rt_desc = thread_ctx->stream_rt_desc(); - const auto& stream_type = stream_rt_desc.stream_type(); - return typeid(stream_type); -} - -// Threads with the same stream_type share a thread_consistent_id. -// e.g. -// Given there are 8 gpu thread in a single process. -// thread #0 is active in process #0, while others are not. -// thread #1 is active in process #1, while others are not. -// ... -// thread #7 is active in process #7, while others are not. -// to make them communicate with each other, we can allocate thread_consistent_id 1 to all those -// gpu threads in all processes. -void GetWorkerThreadInitializer(intrusive::shared_ptr vm, - std::function* Initializer) { - std::set stream_type_indexes; - INTRUSIVE_UNSAFE_FOR_EACH_PTR(thread_ctx, vm->mut_thread_ctx_list()) { - const auto& stream_type = thread_ctx->stream_rt_desc().stream_type(); - if (!stream_type.SupportingTransportInstructions()) { continue; } - stream_type_indexes.insert(GetStreamTypeIndex(thread_ctx)); - } - HashMap stream_type_index2consistent_id; - int64_t thread_consistent_id = kThreadConsistentIdScheduler + 1; - for (const auto& stream_type_index : stream_type_indexes) { - VLOG(3) << "transport stream type: " << stream_type_index.name(); - stream_type_index2consistent_id[stream_type_index] = thread_consistent_id++; - } - *Initializer = [stream_type_index2consistent_id](vm::ThreadCtx* thread_ctx) { - const auto& stream_type_index = GetStreamTypeIndex(thread_ctx); - const auto& iter = stream_type_index2consistent_id.find(stream_type_index); - if (iter != stream_type_index2consistent_id.end()) { - CHECK_JUST(InitThisThreadConsistentId(iter->second, stream_type_index.name())); - } - OF_PROFILER_NAME_THIS_HOST_THREAD("_VM::Worker"); - }; -} - void WorkerLoop(vm::ThreadCtx* thread_ctx, const std::function& Initializer) { Initializer(thread_ctx); while (thread_ctx->mut_notifier()->WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) { @@ -107,36 +75,45 @@ void WorkerLoop(vm::ThreadCtx* thread_ctx, const std::function( - vm::MakeVmDesc(resource, this_machine_id).Get()); + engine_ = intrusive::make_shared(); OF_PROFILER_NAME_THIS_HOST_THREAD("_Main"); - std::function WorkerInitializer; - GetWorkerThreadInitializer(vm_, &WorkerInitializer); - CHECK_JUST(ForEachThreadCtx(vm_.Mutable(), [&](vm::ThreadCtx* thread_ctx) -> Maybe { - auto thread = std::make_unique(&WorkerLoop, thread_ctx, WorkerInitializer); - worker_threads_.push_back(std::move(thread)); - return Maybe::Ok(); - })); std::function SchedulerInitializer; GetSchedulerThreadInitializer(&SchedulerInitializer); schedule_thread_ = std::thread(&VirtualMachine::ScheduleLoop, this, SchedulerInitializer); + transport_local_dep_object_.Reset(); } namespace { -void MakeCtrlSeqInstructions(vm::VirtualMachineEngine* vm, vm::InstructionMsgList* list, - const std::function& ComputeCallback) { - const auto& phy_instr_operand = std::make_shared(ComputeCallback); - auto instruction = intrusive::make_shared( - vm, "CtrlComputeRankFrontSeqCallback", std::shared_ptr(), - phy_instr_operand); - list->EmplaceBack(std::move(instruction)); +Maybe> GetBarrierStream() { + auto device = JUST(Device::New("cpu")); + return Stream::New(device, StreamRole::kBarrier); +} + +void MakeBarrierInstructions(vm::InstructionMsgList* list, + const std::function& BarrierCallback) { + auto* vm = Global::Get(); + { + const auto& phy_instr_operand = std::make_shared([]() {}); + auto stream = CHECK_JUST(GetBarrierStream()); + auto instruction = intrusive::make_shared( + CHECK_JUST(vm->GetVmStream(stream)), SingletonPtr(), + phy_instr_operand); + list->EmplaceBack(std::move(instruction)); + } + { + const auto& phy_instr_operand = std::make_shared(BarrierCallback); + auto stream = CHECK_JUST(GetBarrierStream()); + auto instruction = intrusive::make_shared( + CHECK_JUST(vm->GetVmStream(stream)), SingletonPtr(), + phy_instr_operand); + list->EmplaceBack(std::move(instruction)); + } } } // namespace @@ -144,30 +121,30 @@ void MakeCtrlSeqInstructions(vm::VirtualMachineEngine* vm, vm::InstructionMsgLis void VirtualMachine::ControlSync() { auto bc = std::make_shared(1); vm::InstructionMsgList list; - MakeCtrlSeqInstructions(mut_vm(), &list, [bc] { bc->Decrease(); }); + MakeBarrierInstructions(&list, [bc] { bc->Decrease(); }); CHECK_JUST(Receive(&list)); CHECK_JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); } Maybe VirtualMachine::CloseVMThreads() { - CHECK_OR_RETURN(!vm_threads_closed_); + CHECK_OR_RETURN(!disable_vm_threads_) << "vm threads closed"; ControlSync(); pending_notifier_.Close(); schedule_thread_.join(); - vm_threads_closed_ = true; + disable_vm_threads_ = true; return Maybe::Ok(); } VirtualMachine::~VirtualMachine() { - if (!vm_threads_closed_) { CHECK_JUST(CloseVMThreads()); } - CHECK(vm_->SchedulerEmpty()); - vm_.Reset(); + if (!disable_vm_threads_) { CHECK_JUST(CloseVMThreads()); } + CHECK(engine_->SchedulerEmpty()); + engine_.Reset(); } std::function()> VirtualMachine::GetPredicatorNoMoreInstructionsFinished() { auto last_total_erased = std::make_shared(0); auto* vm = Global::Get(); - if (vm != nullptr) { *last_total_erased = vm->vm().total_erased_instruction_cnt(); } + if (vm != nullptr) { *last_total_erased = vm->engine_->total_erased_instruction_cnt(); } return [last_total_erased]() -> Maybe { auto* vm = Global::Get(); CHECK_NOTNULL_OR_RETURN(vm) << "virtual machine not initialized."; @@ -179,7 +156,7 @@ std::function()> VirtualMachine::GetPredicatorNoMoreInstructionsFini } bool VirtualMachine::NoMoreErasedInstructions(size_t* last_total_erased_instruction_cnt) const { - size_t cnt = vm_->total_erased_instruction_cnt(); + size_t cnt = engine_->total_erased_instruction_cnt(); bool no_more_erased = (*last_total_erased_instruction_cnt == cnt); *last_total_erased_instruction_cnt = cnt; return no_more_erased; @@ -187,29 +164,29 @@ bool VirtualMachine::NoMoreErasedInstructions(size_t* last_total_erased_instruct std::string VirtualMachine::GetBlockingDebugString() { size_t limit = EnvInteger(); - return vm_->GetLivelyInstructionListDebugString(limit); + return engine_->GetLivelyInstructionListDebugString(limit); } Maybe VirtualMachine::Receive(vm::InstructionMsgList* instr_list) { if (unlikely(pthread_fork::IsForkedSubProcess())) { INTRUSIVE_FOR_EACH_PTR(instr_msg, instr_list) { - const auto& parallel_desc = instr_msg->phy_instr_parallel_desc(); - CHECK_OR_RETURN(!parallel_desc || parallel_desc->device_type() == DeviceType::kCPU) + const auto& device = instr_msg->stream().device(); + CHECK_OR_RETURN(device->enum_type() == DeviceType::kCPU) << pthread_fork::kOfCudaNotSupportInForkedSubProcess; - // NOTE: operate `vm_` in forked subprocesses causes mysterious problems. + // NOTE: operate `engine_` in forked subprocesses causes mysterious problems. // `ComputeInFuseMode` will be replaced by `Compute` soon. - instr_msg->mut_instr_type_id()->instruction_type().ComputeInFuseMode(instr_msg); + instr_msg->instruction_type().ComputeInFuseMode(instr_msg); } - } else if (unlikely(vm_threads_closed_)) { + } else if (unlikely(disable_vm_threads_)) { JUST(RunInCurrentThread(instr_list)); } else { const int64_t kHighWaterMark = GetInstructionHighWaterMark(); - if (vm_->flying_instruction_cnt() > kHighWaterMark) { + if (engine_->flying_instruction_cnt() > kHighWaterMark) { JUST(Global::Get()->WithScopedRelease([&, this]() -> Maybe { auto bc = std::make_shared(1); - vm_->InsertProbe([bc](vm::VirtualMachineEngine* vm) { + engine_->InsertProbe([bc](vm::VirtualMachineEngine* engine) { const int64_t kLowWaterMark = GetInstructionLowWaterMark(); - if (vm->flying_instruction_cnt() > kLowWaterMark) { return false; } + if (engine->flying_instruction_cnt() > kLowWaterMark) { return false; } bc->Decrease(); return true; }); @@ -218,7 +195,7 @@ Maybe VirtualMachine::Receive(vm::InstructionMsgList* instr_list) { return Maybe::Ok(); })); } - if (JUST(vm_->Receive(instr_list))) { + if (JUST(engine_->Receive(instr_list))) { // old pending_instruction_list is empty. pending_notifier_.Notify(); } @@ -238,16 +215,26 @@ class SingleThreadScheduleCtx : public vm::ScheduleCtx { } }; -void ScheduleUntilVMEmpty(vm::VirtualMachineEngine* vm, const vm::ScheduleCtx& schedule_ctx) { - do { vm->Schedule(schedule_ctx); } while (!(vm->SchedulerEmpty())); +void ScheduleUntilVMEmpty(vm::VirtualMachineEngine* engine, const vm::ScheduleCtx& schedule_ctx) { + do { engine->Schedule(schedule_ctx); } while (!(engine->SchedulerEmpty())); } } // namespace +Maybe VirtualMachine::NotifyOrRunScheduler() { + if (unlikely(pthread_fork::IsForkedSubProcess() || disable_vm_threads_)) { + ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx()); + } else { + pending_notifier_.Notify(); + } + return Maybe::Ok(); +} + Maybe VirtualMachine::RunInCurrentThread(vm::InstructionMsgList* instr_list) { - CHECK_OR_RETURN(vm_->SchedulerEmpty()) << "vm scheduler not empty. May be a fatal error occured"; - JUST(vm_->Receive(instr_list)); - ScheduleUntilVMEmpty(vm_.Mutable(), SingleThreadScheduleCtx()); + CHECK_OR_RETURN(engine_->SchedulerEmpty()) + << "vm scheduler not empty. May be a fatal error occured"; + JUST(engine_->Receive(instr_list)); + ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx()); return Maybe::Ok(); } @@ -268,17 +255,16 @@ class MultiThreadScheduleCtx : public vm::ScheduleCtx { void VirtualMachine::ScheduleLoop(const std::function& Initializer) { Initializer(); MultiThreadScheduleCtx schedule_ctx{}; - auto* vm = mut_vm(); while (pending_notifier_.WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) { OF_PROFILER_RANGE_GUARD("VirtualMachine::ScheduleLoop"); auto start = std::chrono::steady_clock::now(); static constexpr int kWorkingMicroseconds = 1000; - // Every time this thread wakes up, vm is scheduled for about `kWorkingMicroseconds`. + // Every time this thread wakes up, engine_ is scheduled for about `kWorkingMicroseconds`. // The cost of os thread switching is about 5-10 microseconds. Doing more scheduling in // a single waiting up can reach higher performance. do { static constexpr int kNumSchedulingPerTimoutTest = 10000; - // Every time kWorkingMicroseconds timeout tested, vm is scheduled for about + // Every time kWorkingMicroseconds timeout tested, engine_ is scheduled for about // kNumSchedulingPerTimoutTest. // The cost of `MicrosecondsFrom(start)` is about 400ns, while the empty scheduling costs // about 10ns. @@ -287,24 +273,146 @@ void VirtualMachine::ScheduleLoop(const std::function& Initializer) { // Use SchedulerThreadUnsafeEmpty to avoid acquiring mutex lock. // It's safe to use SchedulerThreadUnsafeEmpty here. pending_notifier_.notified_cnt_ will be // greater than zero when inconsistency between - // vm->pending_msg_list.list_head_.list_head_.container_ and - // vm->pending_msg_list.list_head_.list_head_.size_ occured. hence the pending + // engine_->pending_msg_list.list_head_.list_head_.container_ and + // engine_->pending_msg_list.list_head_.list_head_.size_ occured. hence the pending // instructions // will get handled in the next iteration. // VirtualMachine::Receive may be less effiencient if the thread safe version - // `vm->SchedulerEmpty()` + // `engine_->SchedulerEmpty()` // used // here, because VirtualMachine::ScheduleLoop is more likely to get the mutex lock. - do { vm->Schedule(schedule_ctx); } while (!vm->SchedulerThreadUnsafeEmpty()); + do { engine_->Schedule(schedule_ctx); } while (!engine_->SchedulerThreadUnsafeEmpty()); } while (++i < kNumSchedulingPerTimoutTest); } while (MicrosecondsFrom(start) < kWorkingMicroseconds); } - ScheduleUntilVMEmpty(vm, schedule_ctx); - CHECK_JUST(ForEachThreadCtx(vm_.Mutable(), [&](vm::ThreadCtx* thread_ctx) -> Maybe { + ScheduleUntilVMEmpty(engine_.Mutable(), schedule_ctx); + CHECK_JUST(ForEachThreadCtx(engine_.Mutable(), [&](vm::ThreadCtx* thread_ctx) -> Maybe { thread_ctx->mut_notifier()->Close(); return Maybe::Ok(); })); - for (const auto& worker_thread : worker_threads_) { worker_thread->join(); } + { + std::unique_lock lock(worker_threads_mutex_); + for (const auto& worker_thread : worker_threads_) { worker_thread->join(); } + } + scheduler_stopped_ = true; +} + +intrusive::shared_ptr VirtualMachine::FindOrCreateScheduleLocalDepObject( + Symbol device, StreamRole stream_role) { + std::unique_lock lock(creating_stream_and_thread_ctx_mutex_); + auto key = std::make_pair(device, stream_role); + intrusive::shared_ptr* ptr = &device_stream_role2local_dep_object_[key]; + if (!*ptr) { *ptr = intrusive::make_shared(); } + return *ptr; +} + +intrusive::shared_ptr VirtualMachine::FindOrCreateTransportLocalDepObject() { + std::unique_lock lock(creating_stream_and_thread_ctx_mutex_); + if (!transport_local_dep_object_) { + transport_local_dep_object_ = intrusive::make_shared(); + } + return transport_local_dep_object_; +} + +Maybe VirtualMachine::CreateStream(Symbol device, StreamRole stream_role) { + std::unique_lock lock(creating_stream_and_thread_ctx_mutex_); + vm::ThreadCtx* thread_ctx = JUST(FindOrCreateThreadCtx(device, stream_role)); + return JUST(CreateStream(thread_ctx, device, stream_role)); +} + +Maybe VirtualMachine::GetVmStream(Symbol stream) { + if (stream->unique_stream_id() >= unique_stream_id2vm_stream_.size()) { + std::unique_lock lock(creating_stream_and_thread_ctx_mutex_); + if (stream->unique_stream_id() >= unique_stream_id2vm_stream_.size()) { + auto* stream_mgr = JUST(GlobalMaybe()); + for (int i = unique_stream_id2vm_stream_.size(); i <= stream->unique_stream_id(); ++i) { + Symbol cur_stream = JUST(stream_mgr->GetStreamSymbol(i)); + CHECK_EQ_OR_RETURN(cur_stream->unique_stream_id(), i) + << "invalid Stream::unique_stream_id()"; + *unique_stream_id2vm_stream_.MutableOrAdd(cur_stream->unique_stream_id()) = + JUST(CreateStream(cur_stream->device(), cur_stream->stream_role())); + } + } + } + return JUST(VectorAt(unique_stream_id2vm_stream_, stream->unique_stream_id())); +} + +Maybe VirtualMachine::FindOrCreateThreadCtx(Symbol device, + StreamRole stream_role) { + std::unique_lock lock(creating_stream_and_thread_ctx_mutex_); + vm::ThreadCtx** thread_ctx_ptr = nullptr; + if (StreamOnIndependentThread::Visit(stream_role)) { + auto key = std::make_pair(device->enum_type(), stream_role); + thread_ctx_ptr = &devcie_type_stream_role_2independent_thread_ctx_[key]; + } else { + thread_ctx_ptr = &devcie_type2non_independent_thread_ctx_[device->enum_type()]; + } + if (*thread_ctx_ptr == nullptr) { *thread_ctx_ptr = JUST(CreateThreadCtx(device, stream_role)); } + return *thread_ctx_ptr; +} + +Maybe VirtualMachine::CreateThreadCtx(Symbol device, + StreamRole stream_role) { + std::unique_lock lock(creating_stream_and_thread_ctx_mutex_); + // thread_ctx_ptr may be used after timout. + auto thread_ctx_ptr = std::make_shared(nullptr); + { + auto bc = std::make_shared(1); + engine_->InsertProbe([thread_ctx_ptr, bc](vm::VirtualMachineEngine* engine) { + auto thread_ctx = intrusive::make_shared(); + engine->mut_thread_ctx_list()->PushBack(thread_ctx.Mutable()); + *thread_ctx_ptr = thread_ctx.Mutable(); + bc->Decrease(); + return true; + }); + JUST(NotifyOrRunScheduler()); + JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); + } + auto* thread_ctx = *thread_ctx_ptr; + { + const auto& WorkerInitializer = [device, stream_role](vm::ThreadCtx* thread_ctx) { + int device_type_value = static_cast(device->enum_type()); + CHECK_GT(device_type_value, 0); + std::string device_tag = *CHECK_JUST(DeviceTag4DeviceType(device->enum_type())); + if (!StreamOnIndependentThread::Visit(stream_role)) { + CHECK_JUST(InitThisThreadConsistentId(device_type_value + kThreadConsistentIdScheduler, + device_tag)); + } + OF_PROFILER_NAME_THIS_HOST_THREAD("_VM::Worker_" + device_tag); + }; + auto thread = std::make_unique(&WorkerLoop, thread_ctx, WorkerInitializer); + { + std::unique_lock lock(worker_threads_mutex_); + worker_threads_.push_back(std::move(thread)); + } + } + return thread_ctx; +} + +Maybe VirtualMachine::CreateStream(vm::ThreadCtx* thread_ctx, Symbol device, + StreamRole stream_role) { + std::unique_lock lock(creating_stream_and_thread_ctx_mutex_); + // stream_ptr may be used after timout. + auto stream_ptr = std::make_shared(nullptr); + auto bc = std::make_shared(1); + intrusive::shared_ptr schedule_local_dep_object = + FindOrCreateScheduleLocalDepObject(device, stream_role); + Optional> transport_local_dep_object; + if (IsCommNetStream::Visit(stream_role)) { + transport_local_dep_object = FindOrCreateTransportLocalDepObject(); + } + engine_->InsertProbe([stream_ptr, thread_ctx, device, stream_role, bc, schedule_local_dep_object, + transport_local_dep_object](vm::VirtualMachineEngine* engine) { + auto stream = intrusive::make_shared( + thread_ctx, device, stream_role, schedule_local_dep_object, transport_local_dep_object); + thread_ctx->mut_stream_list()->PushBack(stream.Mutable()); + *stream_ptr = stream.Mutable(); + bc->Decrease(); + return true; + }); + JUST(NotifyOrRunScheduler()); + JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); + return *stream_ptr; } } // namespace oneflow diff --git a/oneflow/core/vm/virtual_machine.h b/oneflow/core/vm/virtual_machine.h index 29e17f0aa3e..2f06401b2d2 100644 --- a/oneflow/core/vm/virtual_machine.h +++ b/oneflow/core/vm/virtual_machine.h @@ -16,47 +16,79 @@ limitations under the License. #ifndef ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_ #define ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_ +#include #include "oneflow/core/common/notifier.h" -#include "oneflow/core/vm/vm_desc.h" #include "oneflow/core/vm/virtual_machine_engine.h" #include "oneflow/core/thread/thread_pool.h" +#include "oneflow/core/common/stream_role.h" +#include "oneflow/core/common/steady_vector.h" namespace oneflow { class InstructionsBuilder; +class Device; class VirtualMachine final { public: VirtualMachine(const VirtualMachine&) = delete; VirtualMachine(VirtualMachine&&) = delete; - VirtualMachine(const Resource& resource, int64_t this_machine_id); + VirtualMachine(); ~VirtualMachine(); static std::function()> GetPredicatorNoMoreInstructionsFinished(); - bool NoMoreErasedInstructions(size_t* last_total_erased_instruction_cnt) const; + intrusive::shared_ptr FindOrCreateTransportLocalDepObject(); + std::string GetBlockingDebugString(); Maybe Receive(vm::InstructionMsgList* instr_list); - const vm::VirtualMachineEngine& vm() const { return *vm_; } - Maybe CloseVMThreads(); + Maybe GetVmStream(Symbol stream); + private: friend class InstructionsBuilder; void ScheduleLoop(const std::function& Initializer); - vm::VirtualMachineEngine* mut_vm() { return vm_.Mutable(); } + intrusive::shared_ptr FindOrCreateScheduleLocalDepObject( + Symbol device, StreamRole stream_role); + bool NoMoreErasedInstructions(size_t* last_total_erased_instruction_cnt) const; + + const vm::VirtualMachineEngine& engine() const { return *engine_; } + vm::VirtualMachineEngine* mut_engine() { return engine_.Mutable(); } + void ControlSync(); + Maybe FindOrCreateThreadCtx(Symbol device, StreamRole stream_role); + Maybe CreateThreadCtx(Symbol device, StreamRole stream_role); + Maybe CreateStream(Symbol device, StreamRole stream_role); + + Maybe CreateStream(vm::ThreadCtx* thread_ctx, Symbol device, + StreamRole stream_role); Maybe RunInCurrentThread(vm::InstructionMsgList* instr_list); - bool vm_threads_closed_; - intrusive::shared_ptr vm_; + Maybe NotifyOrRunScheduler(); + + bool disable_vm_threads_; + bool scheduler_stopped_; + intrusive::shared_ptr engine_; + // for asynchronized execution + std::mutex worker_threads_mutex_; std::list> worker_threads_; + + // for creating vm::Stream and vm::ThreadCtx + std::recursive_mutex creating_stream_and_thread_ctx_mutex_; + HashMap devcie_type2non_independent_thread_ctx_; + HashMap, vm::ThreadCtx*> + devcie_type_stream_role_2independent_thread_ctx_; + HashMap, StreamRole>, intrusive::shared_ptr> + device_stream_role2local_dep_object_; + intrusive::shared_ptr transport_local_dep_object_; + SteadyVector unique_stream_id2vm_stream_; + std::thread schedule_thread_; Notifier pending_notifier_; }; diff --git a/oneflow/core/vm/virtual_machine_engine.cpp b/oneflow/core/vm/virtual_machine_engine.cpp index 05052ce654a..5d2a4b157df 100644 --- a/oneflow/core/vm/virtual_machine_engine.cpp +++ b/oneflow/core/vm/virtual_machine_engine.cpp @@ -14,21 +14,20 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/virtual_machine_engine.h" -#include "oneflow/core/vm/vm_desc.h" #include "oneflow/core/vm/instruction_type.h" +#include "oneflow/core/vm/fuse_instruction_type.h" #include "oneflow/core/vm/fuse_phy_instr_operand.h" #include "oneflow/core/vm/barrier_phy_instr_operand.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/framework/device.h" -#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/platform/include/pthread_fork.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/global.h" +#include "oneflow/core/common/singleton_ptr.h" #include "oneflow/core/common/foreign_lock_helper.h" -#include namespace oneflow { namespace vm { @@ -80,16 +79,14 @@ namespace { bool FusableBetween(InstructionFuseType fuse_type, InstructionMsg* instr_msg, InstructionMsg* prev_instr_msg) { - if (unlikely(instr_msg->instr_type_id().instruction_type().fuse_type() != fuse_type)) { - return false; - } - auto* phy_instr_stream = instr_msg->phy_instr_stream(); - if (unlikely(phy_instr_stream == nullptr)) { return false; } + if (unlikely(instr_msg->instruction_type().fuse_type() != fuse_type)) { return false; } + auto* stream = instr_msg->mut_stream(); + if (unlikely(stream == nullptr)) { return false; } auto* sequential_dep = instr_msg->phy_instr_operand()->stream_sequential_dependence(); if (unlikely(sequential_dep == nullptr)) { return false; } if (unlikely(prev_instr_msg == nullptr)) { return true; } - if (unlikely(phy_instr_stream != prev_instr_msg->phy_instr_stream())) { return false; } + if (unlikely(stream != prev_instr_msg->mut_stream())) { return false; } if (unlikely(sequential_dep != prev_instr_msg->phy_instr_operand()->stream_sequential_dependence())) { return false; @@ -108,9 +105,8 @@ void VirtualMachineEngine::MakeAndAppendFusedInstruction( } auto* begin = fused_instr_msg_list.Begin(); auto phy_instr_operand = std::make_shared(std::move(fused_instr_msg_list)); - const auto* stream_tag = begin->phy_instr_stream()->stream_type().stream_tag(); auto instr_msg = intrusive::make_shared( - this, std::string(stream_tag) + ".Fuse", begin->phy_instr_parallel_desc(), phy_instr_operand); + begin->mut_stream(), SingletonPtr(), phy_instr_operand); pending_instr_msgs->EmplaceBack(std::move(instr_msg)); } @@ -190,18 +186,12 @@ void VirtualMachineEngine::ReleaseFinishedInstructions(const ScheduleCtx& schedu OF_PROFILER_RANGE_POP(); } -int64_t VirtualMachineEngine::this_machine_id() const { - CHECK_EQ(machine_id_range().size(), 1); - return machine_id_range().begin(); -} - void VirtualMachineEngine::MakeInstructions(InstructionMsg* instr_msg, /*out*/ InstructionList* new_instruction_list) { - const auto& instruction_type = instr_msg->instr_type_id().instruction_type(); - bool is_barrier_instruction = instruction_type.IsFrontSequential(); - Stream* stream = CHECK_NOTNULL(instr_msg->phy_instr_stream()); - const auto& pd = instr_msg->phy_instr_parallel_desc(); - intrusive::shared_ptr instr = stream->NewInstruction(instr_msg, pd); + const auto& instruction_type = instr_msg->instruction_type(); + bool is_barrier_instruction = instruction_type.IsBarrier(); + Stream* stream = CHECK_NOTNULL(instr_msg->mut_stream()); + intrusive::shared_ptr instr = stream->NewInstruction(instr_msg); LivelyInstructionListPushBack(instr.Mutable()); if (unlikely(is_barrier_instruction)) { mut_barrier_instruction_list()->PushBack(instr.Mutable()); @@ -324,58 +314,6 @@ void VirtualMachineEngine::DispatchInstruction(Instruction* instruction, } } -void VirtualMachineEngine::__Init__(const VmDesc& vm_desc) { - mut_vm_resource_desc()->CopyFrom(vm_desc.vm_resource_desc()); - CHECK_GT(vm_desc.machine_id_range().size(), 0); - *mut_machine_id_range() = vm_desc.machine_id_range(); - INTRUSIVE_UNSAFE_FOR_EACH_PTR(stream_desc, &vm_desc.stream_type2desc()) { - if (stream_desc->num_threads() == 0) { continue; } - auto stream_rt_desc = intrusive::make_shared(stream_desc); - mut_stream_type2stream_rt_desc()->Insert(stream_rt_desc.Mutable()); - BalancedSplitter bs(stream_desc->parallel_num(), stream_desc->num_threads()); - for (int64_t i = 0, rel_global_device_id = 0; i < stream_desc->num_threads(); ++i) { - auto thread_ctx = intrusive::make_shared(stream_rt_desc.Get()); - mut_thread_ctx_list()->PushBack(thread_ctx.Mutable()); - for (int j = bs.At(i).begin(); j < bs.At(i).end(); ++j, ++rel_global_device_id) { - StreamId stream_id; - stream_id.__Init__(&stream_desc->stream_type(), - this_start_global_device_id() + rel_global_device_id); - auto stream = intrusive::make_shared( - thread_ctx.Mutable(), stream_id, vm_resource_desc().max_device_num_per_machine()); - stream_rt_desc->add_stream(stream); - thread_ctx->mut_stream_list()->PushBack(stream.Mutable()); - } - } - } -} - -void VirtualMachineEngine::GetCachedInstrTypeIdAndPhyInstrStream(const std::string& instr_type_name, - int device_id, - InstrTypeId* instr_type_id, - Stream** stream) { - auto* cache = &instr_type_name2rt_instr_type_id_; - auto iter = cache->find(instr_type_name); - if (unlikely(iter == cache->end())) { - const auto& instr_type_id_val = LookupInstrTypeId(instr_type_name); - const auto* stream_type = &instr_type_id_val.stream_type(); - auto* stream_rt_desc = this->mut_stream_type2stream_rt_desc()->FindPtr(stream_type); - iter = cache->emplace(instr_type_name, RtInstrTypeId(instr_type_id_val, stream_rt_desc)).first; - } - instr_type_id->CopyFrom(iter->second.instr_type_id()); - *stream = iter->second.GetStream(device_id); -} - -void VirtualMachineEngine::GetInstrTypeIdAndSoleStream(const std::string& instr_type_name, - InstrTypeId* instr_type_id, - Stream** stream) { - instr_type_id->CopyFrom(LookupInstrTypeId(instr_type_name)); - const auto* stream_type = &instr_type_id->stream_type(); - auto* stream_rt_desc = this->mut_stream_type2stream_rt_desc()->FindPtr(stream_type); - *stream = stream_rt_desc->GetSoleStream(); -} - -int64_t InstructionMaxRunningSeconds() { return 60 * 5; } - // Returns true if old pending_instruction_list is empty Maybe VirtualMachineEngine::Receive(InstructionMsgList* compute_instr_msg_list) { OF_PROFILER_RANGE_GUARD("vm:Receive"); @@ -387,13 +325,6 @@ Maybe VirtualMachineEngine::Receive(InstructionMsgList* compute_instr_msg_ return old_list_empty; } -Maybe VirtualMachineEngine::Receive( - intrusive::shared_ptr&& compute_instr_msg) { - InstructionMsgList instr_msg_list; - instr_msg_list.EmplaceBack(std::move(compute_instr_msg)); - return Receive(&instr_msg_list); -} - bool VirtualMachineEngine::OnSchedulerThread(const StreamType& stream_type) { return stream_type.OnSchedulerThread() || pthread_fork::IsForkedSubProcess(); } @@ -456,7 +387,7 @@ bool VirtualMachineEngine::OnSchedulerThread(const StreamType& stream_type) { // instructions are scarcely received by vm, there is no need for vm to run // VirtualMachineEngine::TryRunBarrierInstruction every time VirtualMachineEngine::Schedule run. On // the other hand, `barrier_instruction_hook_.size() == 0` is more lightweight than -// `lively_instruction_list_.Begin()?->instr_msg().instr_type_id().instruction_type().IsFrontSequential()` +// `lively_instruction_list_.Begin()?->instr_msg().instruction_type().IsBarrier()` // void VirtualMachineEngine::TryRunBarrierInstruction(const ScheduleCtx& schedule_ctx) { auto* sequnential_instruction = mut_barrier_instruction_list()->Begin(); @@ -465,10 +396,9 @@ void VirtualMachineEngine::TryRunBarrierInstruction(const ScheduleCtx& schedule_ // All instructions before `sequnential_instruction` are handled now, it's time to handle // `sequnential_instruction`. OF_PROFILER_RANGE_GUARD("RunBarrierInstruction"); - const auto& instr_type_id = sequnential_instruction->instr_msg().instr_type_id(); - const auto& instruction_type = instr_type_id.instruction_type(); - CHECK(instruction_type.IsFrontSequential()); - const StreamType& stream_type = instr_type_id.stream_type(); + const auto& instruction_type = sequnential_instruction->instr_msg().instruction_type(); + CHECK(instruction_type.IsBarrier()); + const StreamType& stream_type = sequnential_instruction->instr_msg().stream().stream_type(); CHECK(OnSchedulerThread(stream_type)); stream_type.Run(sequnential_instruction); mut_barrier_instruction_list()->Erase(sequnential_instruction); diff --git a/oneflow/core/vm/virtual_machine_engine.h b/oneflow/core/vm/virtual_machine_engine.h index 000dc38ab49..4b7df3a182b 100644 --- a/oneflow/core/vm/virtual_machine_engine.h +++ b/oneflow/core/vm/virtual_machine_engine.h @@ -20,13 +20,10 @@ limitations under the License. #include "oneflow/core/common/maybe.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/stream.h" -#include "oneflow/core/vm/stream_runtime_desc.h" -#include "oneflow/core/vm/runtime_instr_type_id.h" #include "oneflow/core/vm/thread_ctx.h" #include "oneflow/core/vm/vm_object.h" #include "oneflow/core/vm/vm_resource_desc.h" #include "oneflow/core/common/range.h" -#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/intrusive/mutexed_list.h" #include "oneflow/core/intrusive/object_pool.h" #include "oneflow/core/vm/probe.h" @@ -45,7 +42,6 @@ class ScheduleCtx { virtual void OnWorkerLoadPending(vm::ThreadCtx* thread_ctx) const = 0; }; -class VmDesc; class VirtualMachineEngine final : public intrusive::Base { public: // types @@ -58,16 +54,8 @@ class VirtualMachineEngine final : public intrusive::Base { intrusive::List; using InstructionMsgMutexedList = intrusive::MutexedList; - using StreamType2StreamRtDesc = - intrusive::SkipList; // Getters - const VmResourceDesc& vm_resource_desc() const { - if (vm_resource_desc_) { return vm_resource_desc_.Get(); } - static const auto default_val = intrusive::make_shared(); - return default_val.Get(); - } - const Range& machine_id_range() const { return machine_id_range_; } std::size_t flying_instruction_cnt() const { return pending_msg_list().thread_unsafe_size() + local_pending_msg_list().size() + (total_inserted_instruction_cnt() - total_erased_instruction_cnt()); @@ -83,46 +71,22 @@ class VirtualMachineEngine final : public intrusive::Base { } const InstructionMsgMutexedList& pending_msg_list() const { return pending_msg_list_; } const InstructionMsgList& local_pending_msg_list() const { return local_pending_msg_list_; } - const StreamType2StreamRtDesc& stream_type2stream_rt_desc() const { - return stream_type2stream_rt_desc_; - } // Setters - VmResourceDesc* mut_vm_resource_desc() { - if (!vm_resource_desc_) { vm_resource_desc_ = intrusive::make_shared(); } - return vm_resource_desc_.Mutable(); - } - Range* mut_machine_id_range() { return &machine_id_range_; } ActiveStreamList* mut_active_stream_list() { return &active_stream_list_; } ThreadCtxList* mut_thread_ctx_list() { return &thread_ctx_list_; } LivelyInstructionList* mut_lively_instruction_list() { return &lively_instruction_list_; } BarrierInstructionList* mut_barrier_instruction_list() { return &barrier_instruction_list_; } InstructionMsgMutexedList* mut_pending_msg_list() { return &pending_msg_list_; } InstructionMsgList* mut_local_pending_msg_list() { return &local_pending_msg_list_; } - StreamType2StreamRtDesc* mut_stream_type2stream_rt_desc() { return &stream_type2stream_rt_desc_; } - // methods - void __Init__(const VmDesc& vm_desc); - // Returns true if old pending_instruction_list is empty - Maybe Receive(InstructionMsgList* instr_list); // Returns true if old pending_instruction_list is empty - Maybe Receive(intrusive::shared_ptr&& instruction_msg); + Maybe Receive(InstructionMsgList* compute_instr_msg_list); void Schedule(const ScheduleCtx& schedule_ctx); void Callback(); bool SchedulerThreadUnsafeEmpty() const; bool SchedulerEmpty() const; std::string GetLivelyInstructionListDebugString(int64_t debug_cnt); - int64_t this_machine_id() const; - int64_t this_start_global_device_id() const { - return this_machine_id() * vm_resource_desc().max_device_num_per_machine(); - } - - void GetCachedInstrTypeIdAndPhyInstrStream(const std::string& instr_type_name, int device_id, - InstrTypeId* instr_type_id, Stream** stream); - - void GetInstrTypeIdAndSoleStream(const std::string& instr_type_name, InstrTypeId* instr_type_id, - Stream** stream); - private: using ReadyInstructionList = intrusive::List; @@ -164,11 +128,8 @@ class VirtualMachineEngine final : public intrusive::Base { VirtualMachineEngine() : intrusive_ref_(), - vm_resource_desc_(), - machine_id_range_(), active_stream_list_(), thread_ctx_list_(), - stream_type2stream_rt_desc_(), pending_msg_mutex_(), pending_msg_list_(&pending_msg_mutex_), local_pending_msg_list_(), @@ -181,14 +142,10 @@ class VirtualMachineEngine final : public intrusive::Base { local_probe_list_(), barrier_instruction_list_() {} intrusive::Ref intrusive_ref_; - // fields - intrusive::shared_ptr vm_resource_desc_; - Range machine_id_range_; // lists or maps // Do not change the order of the following fields ActiveStreamList active_stream_list_; ThreadCtxList thread_ctx_list_; - StreamType2StreamRtDesc stream_type2stream_rt_desc_; std::mutex pending_msg_mutex_; InstructionMsgMutexedList pending_msg_list_; // local_pending_msg_list_ should be consider as the cache of pending_msg_list_. @@ -204,7 +161,6 @@ class VirtualMachineEngine final : public intrusive::Base { intrusive::List local_probe_list_; BarrierInstructionList barrier_instruction_list_; - std::map instr_type_name2rt_instr_type_id_; DependenceAccess::object_pool_type access_pool_; InstructionEdge::object_pool_type instruction_edge_pool_; }; diff --git a/oneflow/core/vm/virtual_machine_scope.cpp b/oneflow/core/vm/virtual_machine_scope.cpp index d326c4cee5c..0f6233a194a 100644 --- a/oneflow/core/vm/virtual_machine_scope.cpp +++ b/oneflow/core/vm/virtual_machine_scope.cpp @@ -22,7 +22,7 @@ namespace oneflow { namespace vm { VirtualMachineScope::VirtualMachineScope(const Resource& resource) { - Global::New(resource, GlobalProcessCtx::Rank()); + Global::New(); } VirtualMachineScope::~VirtualMachineScope() { Global::Delete(); } diff --git a/oneflow/core/vm/vm_desc.cpp b/oneflow/core/vm/vm_desc.cpp deleted file mode 100644 index f106d935b4a..00000000000 --- a/oneflow/core/vm/vm_desc.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/vm/vm_desc.h" -#include "oneflow/core/vm/stream_desc.h" -#include "oneflow/core/vm/stream_type.h" -#include "oneflow/core/vm/instruction_type.h" -#include "oneflow/core/common/util.h" - -namespace oneflow { -namespace vm { - -namespace { - -void SetMachineIdRange(Range* range, int64_t machine_num, int64_t this_machine_id) { - *range = Range(this_machine_id, this_machine_id + 1); -} - -intrusive::shared_ptr MakeVmDesc( - const Resource& resource, int64_t this_machine_id, - const std::function&)>& ForEachInstrTypeId) { - std::set stream_types; - ForEachInstrTypeId( - [&](const InstrTypeId& instr_type_id) { stream_types.insert(&instr_type_id.stream_type()); }); - auto vm_desc = - intrusive::make_shared(intrusive::make_shared(resource).Get()); - SetMachineIdRange(vm_desc->mut_machine_id_range(), resource.machine_num(), this_machine_id); - int cnt = 0; - for (const auto* stream_type : stream_types) { - auto stream_desc = stream_type->MakeStreamDesc(resource, this_machine_id); - if (stream_desc) { - ++cnt; - CHECK(vm_desc->mut_stream_type2desc()->Insert(stream_desc.Mutable()).second); - } - } - CHECK_EQ(vm_desc->stream_type2desc().size(), cnt); - return vm_desc; -} - -} // namespace - -intrusive::shared_ptr MakeVmDesc(const Resource& resource, int64_t this_machine_id) { - return MakeVmDesc(resource, this_machine_id, &ForEachInstrTypeId); -} - -intrusive::shared_ptr MakeVmDesc(const Resource& resource, int64_t this_machine_id, - const std::set& instr_type_names) { - const auto& ForEachInstrTypeId = [&](const std::function& Handler) { - for (const auto& instr_type_name : instr_type_names) { - Handler(LookupInstrTypeId(instr_type_name)); - Handler(LookupInstrTypeId(std::string("Infer-") + instr_type_name)); - } - }; - return MakeVmDesc(resource, this_machine_id, ForEachInstrTypeId); -} - -} // namespace vm -} // namespace oneflow diff --git a/oneflow/core/vm/vm_desc.h b/oneflow/core/vm/vm_desc.h deleted file mode 100644 index b28d29db00c..00000000000 --- a/oneflow/core/vm/vm_desc.h +++ /dev/null @@ -1,74 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_VM_MEM_ZONE_TYPE_DESC__H_ -#define ONEFLOW_CORE_VM_MEM_ZONE_TYPE_DESC__H_ - -#include "oneflow/core/vm/stream_desc.h" -#include "oneflow/core/vm/virtual_machine_engine.h" -#include "oneflow/core/vm/vm_resource_desc.h" -#include "oneflow/core/common/range.h" - -namespace oneflow { -namespace vm { - -class VmDesc final : public intrusive::Base { - public: - // types - using StreamType2StreamDesc = intrusive::SkipList; - // Getters - const VmResourceDesc& vm_resource_desc() const { - if (vm_resource_desc_) { return vm_resource_desc_.Get(); } - static const auto default_val = intrusive::make_shared(); - return default_val.Get(); - } - const Range& machine_id_range() const { return machine_id_range_; } - const StreamType2StreamDesc& stream_type2desc() const { return stream_type2desc_; } - // Setters - VmResourceDesc* mut_vm_resource_desc() { - if (!vm_resource_desc_) { vm_resource_desc_ = intrusive::make_shared(); } - return vm_resource_desc_.Mutable(); - } - Range* mut_machine_id_range() { return &machine_id_range_; } - StreamType2StreamDesc* mut_stream_type2desc() { return &stream_type2desc_; } - - // methods - void __Init__(const VmResourceDesc& vm_resource_desc) { __Init__(vm_resource_desc, Range(0, 1)); } - void __Init__(const VmResourceDesc& vm_resource_desc, const Range& machine_id_range) { - mut_vm_resource_desc()->CopyFrom(vm_resource_desc); - *mut_machine_id_range() = machine_id_range; - } - - private: - friend class intrusive::Ref; - intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } - - VmDesc() : intrusive_ref_(), vm_resource_desc_(), machine_id_range_(), stream_type2desc_() {} - intrusive::Ref intrusive_ref_; - // fields - intrusive::shared_ptr vm_resource_desc_; - Range machine_id_range_; - // maps - StreamType2StreamDesc stream_type2desc_; -}; - -intrusive::shared_ptr MakeVmDesc(const Resource& resource, int64_t this_machine_id); -intrusive::shared_ptr MakeVmDesc(const Resource& resource, int64_t this_machine_id, - const std::set& instr_type_names); - -} // namespace vm -} // namespace oneflow - -#endif // ONEFLOW_CORE_VM_MEM_ZONE_TYPE_DESC__H_ diff --git a/oneflow/core/vm/vm_object.h b/oneflow/core/vm/vm_object.h index cfc6b69a784..fae0c74bf38 100644 --- a/oneflow/core/vm/vm_object.h +++ b/oneflow/core/vm/vm_object.h @@ -20,9 +20,6 @@ limitations under the License. #include "oneflow/core/intrusive/flat_msg.h" #include "oneflow/core/intrusive/intrusive.h" #include "oneflow/core/intrusive/object_pool.h" -#include "oneflow/core/vm/id_util.h" -#include "oneflow/core/vm/stream_desc.h" -#include "oneflow/core/job/parallel_desc.h" namespace oneflow { diff --git a/oneflow/core/vm/vm_util.cpp b/oneflow/core/vm/vm_util.cpp index 3a39a93256c..d5ce990e0e6 100644 --- a/oneflow/core/vm/vm_util.cpp +++ b/oneflow/core/vm/vm_util.cpp @@ -20,7 +20,6 @@ limitations under the License. #include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/virtual_machine.h" -#include "oneflow/core/vm/instruction.pb.h" #include "oneflow/core/vm/stream_type.h" #include "oneflow/core/vm/instruction_type.h" #include "oneflow/core/framework/instructions_builder.h" @@ -40,8 +39,8 @@ Maybe Run(vm::InstructionMsgList* instr_msg_list) { Maybe ClusterSync() { auto bc = std::make_shared(1); JUST(PhysicalRun([bc](InstructionsBuilder* builder) -> Maybe { - JUST(builder->ComputeGlobalFrontSeqBarrier()); - JUST(builder->ComputeRankFrontSeqCallback([bc]() { bc->Decrease(); })); + JUST(builder->GlobalSync()); + JUST(builder->Barrier([bc]() { bc->Decrease(); })); return Maybe::Ok(); })); JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); @@ -51,7 +50,7 @@ Maybe ClusterSync() { Maybe CurrentRankSync() { auto bc = std::make_shared(1); JUST(PhysicalRun([bc](InstructionsBuilder* builder) -> Maybe { - JUST(builder->ComputeRankFrontSeqCallback([bc]() { bc->Decrease(); })); + JUST(builder->Barrier([bc]() { bc->Decrease(); })); return Maybe::Ok(); })); JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished())); diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_opkernel.cpp similarity index 96% rename from oneflow/user/kernels/stateful_local_opkernel.cpp rename to oneflow/user/kernels/stateful_opkernel.cpp index 629a795240a..6afbc1bbd07 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.cpp +++ b/oneflow/user/kernels/stateful_opkernel.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/user/kernels/stateful_local_opkernel.h" +#include "oneflow/user/kernels/stateful_opkernel.h" #include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry_manager.h" @@ -370,12 +370,12 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr return Maybe::Ok(); } -/* static */ Maybe StatefulLocalOpKernel::New( +/* static */ Maybe StatefulOpKernel::New( const std::shared_ptr& op_conf, const Symbol& stream, const AttrMap& base_attrs, const std::shared_ptr& parallel_desc, const std::shared_ptr& input_arg_tuple, const std::shared_ptr& output_arg_tuple) { - auto opkernel = std::shared_ptr(new StatefulLocalOpKernel()); + auto opkernel = std::shared_ptr(new StatefulOpKernel()); opkernel->op_conf_ = op_conf; opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf)); opkernel->stream_ = stream; @@ -419,9 +419,9 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr return opkernel; } -StatefulLocalOpKernel::~StatefulLocalOpKernel() = default; +StatefulOpKernel::~StatefulOpKernel() = default; -Maybe StatefulLocalOpKernel::ChooseOpKernel( +Maybe StatefulOpKernel::ChooseOpKernel( const user_op::OpKernel** user_opkernel, bool* need_temp_storage, const AttrMap& attrs, EagerBlobObjectListRawPtr inputs, EagerBlobObjectListRawPtr outputs, ConsistentTensorInferResultRawPtr consistent_tensor_infer_result) { @@ -463,7 +463,7 @@ Maybe StatefulLocalOpKernel::ChooseOpKernel( return Maybe::Ok(); } -void StatefulLocalOpKernel::TryInitOpKernelStateAndCache( +void StatefulOpKernel::TryInitOpKernelStateAndCache( const user_op::OpKernel* op_kernel, DeviceCtx* device_ctx, EagerBlobObjectListRawPtr inputs, EagerBlobObjectListRawPtr outputs, ConsistentTensorInferResultRawPtr consistent_tensor_infer_result, @@ -490,24 +490,20 @@ void StatefulLocalOpKernel::TryInitOpKernelStateAndCache( } } -const user_op::InferTmpSizeFn& StatefulLocalOpKernel::GetInferTmpSizeFn( +const user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn( const user_op::OpKernel* op_kernel) const { return *infer_tmp_size_fn_map_.at(op_kernel); } -vm::EagerBlobObject* StatefulLocalOpKernel::mut_temp_blob_object() { - return tmp_blob_object_.get(); -} +vm::EagerBlobObject* StatefulOpKernel::mut_temp_blob_object() { return tmp_blob_object_.get(); } -user_op::TensorDescInferFn StatefulLocalOpKernel::TensorDescInferFn() const { +user_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const { return tensor_desc_infer_fn_; } -user_op::DataTypeInferFn StatefulLocalOpKernel::DataTypeInferFn() const { - return data_type_infer_fn_; -} +user_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const { return data_type_infer_fn_; } -LocalUserKernelComputeContext* StatefulLocalOpKernel::UpdateComputeContext( +LocalUserKernelComputeContext* StatefulOpKernel::UpdateComputeContext( EagerBlobObjectListRawPtr inputs, EagerBlobObjectListRawPtr outputs, ConsistentTensorInferResultRawPtr consistent_tensor_infer_result, DeviceCtx* device_ctx) { compute_ctx_->Update(inputs, outputs, consistent_tensor_infer_result, device_ctx); diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_opkernel.h similarity index 95% rename from oneflow/user/kernels/stateful_local_opkernel.h rename to oneflow/user/kernels/stateful_opkernel.h index 750b02b7f46..fba5fb4e7d8 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.h +++ b/oneflow/user/kernels/stateful_opkernel.h @@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_USER_KERNELS_STATEFUL_LOCAL_OPKERNEL_H_ -#define ONEFLOW_USER_KERNELS_STATEFUL_LOCAL_OPKERNEL_H_ +#ifndef ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_ +#define ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_ #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/tensor_meta.h" @@ -30,7 +30,7 @@ namespace oneflow { class AttrMap; namespace vm { -struct LocalCallOpKernelUtil; +struct OpCallInstructionUtil; } // namespace vm namespace one { @@ -382,15 +382,15 @@ class LocalUserKernelComputeContext final : public user_op::KernelComputeContext LocalUserKernelBaseContext base_ctx_; }; -class StatefulLocalOpKernel final { +class StatefulOpKernel final { public: - OF_DISALLOW_COPY_AND_MOVE(StatefulLocalOpKernel); - static Maybe New(const std::shared_ptr& op_conf, - const Symbol& stream, const AttrMap& base_attrs, - const std::shared_ptr& parallel_desc, - const std::shared_ptr& input_arg_tuple, - const std::shared_ptr& output_arg_tuple); - ~StatefulLocalOpKernel(); + OF_DISALLOW_COPY_AND_MOVE(StatefulOpKernel); + static Maybe New(const std::shared_ptr& op_conf, + const Symbol& stream, const AttrMap& base_attrs, + const std::shared_ptr& parallel_desc, + const std::shared_ptr& input_arg_tuple, + const std::shared_ptr& output_arg_tuple); + ~StatefulOpKernel(); const Symbol& stream() const { return stream_; } const std::shared_ptr& mem_case() const { return stream_->device()->mem_case(); } const std::string& op_type_name() const { return op_conf_->user_conf().op_type_name(); } @@ -429,8 +429,8 @@ class StatefulLocalOpKernel final { const OperatorConf& op_conf() const { return *op_conf_; } private: - friend struct vm::LocalCallOpKernelUtil; - StatefulLocalOpKernel() = default; + friend struct vm::OpCallInstructionUtil; + StatefulOpKernel() = default; LocalUserKernelComputeContext* UpdateComputeContext( EagerBlobObjectListRawPtr inputs, EagerBlobObjectListRawPtr outputs, ConsistentTensorInferResultRawPtr consistent_tensor_infer_result, DeviceCtx* device_ctx); @@ -487,4 +487,4 @@ class StatefulLocalOpKernel final { } // namespace oneflow -#endif // ONEFLOW_USER_KERNELS_STATEFUL_LOCAL_OPKERNEL_H_ +#endif // ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_ diff --git a/python/oneflow/nn/graph/block.py b/python/oneflow/nn/graph/block.py index 407542bd41a..fa38031ebb1 100644 --- a/python/oneflow/nn/graph/block.py +++ b/python/oneflow/nn/graph/block.py @@ -20,7 +20,7 @@ import oneflow._C import oneflow._oneflow_internal -import oneflow.framework.graph_build_util as graph_build_util +from oneflow.framework import graph_build_util from oneflow.env import get_rank from oneflow.framework.tensor import Tensor, TensorTuple from oneflow.nn.module import Module diff --git a/python/oneflow/test/exceptions/test_device.py b/python/oneflow/test/exceptions/test_device.py index 4aac53368a0..4a1453c3448 100644 --- a/python/oneflow/test/exceptions/test_device.py +++ b/python/oneflow/test/exceptions/test_device.py @@ -39,10 +39,7 @@ def test_device_index(test_case): # device = flow.device("cuda:1000") # flow.Tensor(2, 3).to(device=device) # test_case.assertTrue("CUDA error: invalid device ordinal" in str(exp.exception)) - - with test_case.assertRaises(RuntimeError) as exp: - device = flow.device("cpu:1000") - flow.Tensor(2, 3).to(device=device) + pass if __name__ == "__main__": diff --git a/python/oneflow/test/modules/test_consistent_tensordot.py b/python/oneflow/test/modules/test_consistent_tensordot.py index 517d8ad1c38..cf0abaadd2a 100644 --- a/python/oneflow/test/modules/test_consistent_tensordot.py +++ b/python/oneflow/test/modules/test_consistent_tensordot.py @@ -20,7 +20,7 @@ from oneflow.test_utils.automated_test_util import * -@autotest(n=1, check_graph=False) +@autotest(n=1, check_graph=False, atol=1e-3) def _test_global_tensordot_against_pytorch(test_case, ndim, placement, sbp): k = random(1, 2) * 8 tensordot_dim = random(0, ndim + 1).to(int) diff --git a/python/oneflow/test_utils/automated_test_util/profiler.py b/python/oneflow/test_utils/automated_test_util/profiler.py index 8e6551e9d9d..9d7ff2a24a3 100644 --- a/python/oneflow/test_utils/automated_test_util/profiler.py +++ b/python/oneflow/test_utils/automated_test_util/profiler.py @@ -20,7 +20,9 @@ import torch import oneflow as flow import oneflow.support.env_var_util -import oneflow.test_utils.automated_test_util.torch_flow_dual_object as dual_object_module +from oneflow.test_utils.automated_test_util import ( + torch_flow_dual_object as dual_object_module, +) __all__ = ["profile", "set_profiler_hook", "profile_dual_object"] diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index afc05af4b9e..b0254129ca6 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -23,7 +23,7 @@ import numpy as np import oneflow as flow -import oneflow.test_utils.automated_test_util.profiler as auto_profiler +from oneflow.test_utils.automated_test_util import profiler as auto_profiler flow.backends.cudnn.deterministic = True