diff --git a/oneflow/core/common/constant.h b/oneflow/core/common/constant.h
index 3f8b331bdb4..7760e161128 100644
--- a/oneflow/core/common/constant.h
+++ b/oneflow/core/common/constant.h
@@ -24,6 +24,7 @@ static const int64_t kInvalidSessionId = -1;
 static const std::string kNoPassTag = "";
 static const std::string kMainOp = "main_op";
 static const int64_t kMaxSplitAxis = 6;
+constexpr size_t kMaxNumDims = 8;
 static const std::string kAsymmetricCodeErrorMsg =
     "Maybe executing different code in different ranks, please check if the code is branched and "
     "operates on the global tensor.";
diff --git a/oneflow/core/common/env_var/eager.h b/oneflow/core/common/env_var/eager.h
new file mode 100644
index 00000000000..ad7108ceb2d
--- /dev/null
+++ b/oneflow/core/common/env_var/eager.h
@@ -0,0 +1,28 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_
+#define ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_
+
+#include "oneflow/core/common/env_var/env_var.h"
+
+namespace oneflow {
+
+// NOTE: use env variable 'ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE' to indicate whether to use
+// the infer cache in the naive local op interpreter.
+DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE, true);
+
+}  // namespace oneflow
+#endif  // ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_
diff --git a/oneflow/core/common/stride.cpp b/oneflow/core/common/stride.cpp
index 38552a832f9..ab130076065 100644
--- a/oneflow/core/common/stride.cpp
+++ b/oneflow/core/common/stride.cpp
@@ -15,6 +15,7 @@ limitations under the License.
 */
 #include "oneflow/core/common/stride.h"
+#include "oneflow/core/common/constant.h"
 #include "oneflow/core/common/protobuf.h"
 #include "oneflow/core/common/cplusplus_17.h"
 
@@ -29,7 +30,7 @@ Stride::Stride(const Shape& shape) {
                         std::multiplies<>{});
   } else if (ndim > 0 && shape.elem_cnt() == 0) {
     // 0-size shape
-    std::vector<int64_t> tmp_shape(ndim);
+    small_vector<int64_t, kMaxNumDims> tmp_shape(ndim);
     for (int64_t i = 0; i < ndim; ++i) { tmp_shape[i] = shape.At(i) > 0 ? shape.At(i) : 1; }
     std::exclusive_scan(tmp_shape.rbegin(), tmp_shape.rend(), rbegin(), (int64_t)1,
                         std::multiplies<>{});
diff --git a/oneflow/core/framework/local_tensor_infer_cache.cpp b/oneflow/core/framework/local_tensor_infer_cache.cpp
new file mode 100644
index 00000000000..e4c246d5837
--- /dev/null
+++ b/oneflow/core/framework/local_tensor_infer_cache.cpp
@@ -0,0 +1,209 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/local_tensor_infer_cache.h"
+#include "oneflow/core/framework/tensor_tuple.h"
+#include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/operator/operator.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/common/container_util.h"
+#include "oneflow/core/common/env_var/eager.h"
+#include "oneflow/core/framework/infer_util.h"
+
+namespace oneflow {
+namespace one {
+
+namespace {
+
+Maybe<void> CheckIsDeviceSupportedByOp(const Device& device, const std::string& op_type_name) {
+  if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), "cpu"); }  // NOLINT
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args,
+                                      Symbol<Device> default_device) {
+  for (int i = 0; i < infer_args.input_local_tensor_metas().size(); ++i) {
+    CHECK_OR_RETURN(default_device
+                    == JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device())
+        << Error::RuntimeError()
+        << "Expected all tensors to be on the same device, but found "
+           "at least two devices, "
+        << default_device->ToString() << " (positional 0) and "
+        << JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device()->ToString()
+        << " (positional " << i << ")!";
+  }
+  return Maybe<void>::Ok();
+}
+
+class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext {
+ public:
+  UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr,
+                                        const LocalTensorMetaInferArgs& infer_args,
+                                        OpArgsVector<LocalTensorMeta>* output_tensor_metas)
+      : user_op_expr_(user_op_expr),
+        composed_attrs_(infer_args.attrs(), user_op_expr->base_attrs()),
+        infer_args_(infer_args),
+        output_tensor_metas_(output_tensor_metas) {}
+
+  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {
+    return user_op_expr_->indexed_input_pairs();
+  }
+
+  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {
+    return user_op_expr_->indexed_output_pairs();
+  }
+
+  Symbol<Device>* OutputTensorDevice4ArgNameAndIndex(const std::string& name,
+                                                     int64_t index) override {
+    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();
+    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);
+    CHECK_GE(tuple_index, 0);
+    CHECK_LT(tuple_index, user_op_expr_->output_size());
+    return output_tensor_metas_->at(tuple_index).mut_device();
+  }
+
+  Symbol<Device> InputTensorDevice4ArgNameAndIndex(const std::string& name,
+                                                   int64_t index) const override {
+    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();
+    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);
+    CHECK_GE(tuple_index, 0);
+    CHECK_LT(tuple_index, user_op_expr_->input_size());
+    return infer_args_.input_local_tensor_metas().at(tuple_index)->device();
+  }
+
+ private:
+  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
+      const std::string& attr_name) const override {
+    return composed_attrs_.Attr4Name(attr_name);
+  }
+  const UserOpExpr* user_op_expr_;
+  const ComposedAttrMap composed_attrs_;
+  const LocalTensorMetaInferArgs& infer_args_;
+  OpArgsVector<LocalTensorMeta>* output_tensor_metas_;
+};
+
+Maybe<Symbol<Stream>> InferDeviceAndStream(const UserOpExpr& user_op_expr,
+                                           const Symbol<Device>& default_device,
+                                           const LocalTensorMetaInferArgs& infer_args,
+                                           OpArgsVector<LocalTensorMeta>* output_tensor_metas) {
+  Symbol<Stream> stream;
+  if (!user_op_expr.has_device_and_stream_infer_fn()) {
+    stream = JUST(GetDefaultStreamByDevice(default_device));
+    for (int i = 0; i < user_op_expr.output_size(); i++) {
+      auto& tensor_meta = output_tensor_metas->at(i);
+      *tensor_meta.mut_device() = default_device;
+    }
+  } else {
+    if (!user_op_expr.device_and_stream_infer_fn()) {
+      Symbol<Device> device = infer_args.input_local_tensor_metas().at(0)->device();
+      stream = JUST(GetDefaultStreamByDevice(device));
+    } else {
+      UserOpExprDeviceAndStreamInferContext device_and_stream_ctx(&user_op_expr, infer_args,
+                                                                  output_tensor_metas);
+      stream = JUST(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx));
+    }
+  }
+  return stream;
+}
+
+}  // namespace
+
+size_t LocalTensorMetaInferArgs::hash_value() const {
+  size_t hash_value = std::hash<AttrMap>()(attrs_);
+  HashCombine(&hash_value, std::hash<Symbol<Device>>()(default_device_));
+  const auto& tensor_meta_hash_functor = std::hash<Symbol<LocalTensorMeta>>();
+  for (const auto& tensor_meta : input_local_tensor_metas_) {
+    HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta));
+  }
+  return hash_value;
+}
+
+bool LocalTensorMetaInferArgs::operator==(const LocalTensorMetaInferArgs& other) const {
+  return this->attrs_ == other.attrs_ && this->default_device_ == other.default_device_
+         && this->input_local_tensor_metas_ == other.input_local_tensor_metas_;
+}
+
+Maybe<void> LocalTensorMetaInferArgs::Init(const AttrMap& attrs, Symbol<Device> default_device,
+                                           const TensorTuple& input_tensors) {
+  this->attrs_ = attrs;
+  this->default_device_ = default_device;
+  this->input_local_tensor_metas_.resize(input_tensors.size());
+  JUST(this->InitInputLocalTensorMetas(input_tensors));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) {
+  for (int i = 0; i < input_tensors.size(); ++i) {
+    LocalTensorMeta* local_tensor_meta =
+        dynamic_cast<LocalTensorMeta*>(input_tensors.at(i)->mut_tensor_meta());
+    CHECK_NOTNULL_OR_RETURN(local_tensor_meta);  // NOLINT
+    input_local_tensor_metas_.at(i) = SymbolOf(*local_tensor_meta);
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<const LocalTensorInferResult> LocalTensorInferCache::Infer(
+    const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args) {
+  const auto& default_device = infer_args.default_device();
+  JUST(CheckInputDeviceIdentical(infer_args, default_device));
+  JUST(CheckIsDeviceSupportedByOp(*default_device, user_op_expr.op_type_name()));
+
+  auto result = std::make_unique<LocalTensorInferResult>(user_op_expr.output_size());
+
+  OpArgsVector<LocalTensorMeta> output_mut_metas(user_op_expr.output_size());
+  // Infer devices
+  Symbol<Stream> stream =
+      JUST(InferDeviceAndStream(user_op_expr, default_device, infer_args, &output_mut_metas));
+  result->set_stream(stream);
+
+  {
+    const auto& GetInputTensorMeta = [&](int32_t i) -> const TensorMeta* {
+      return infer_args.input_local_tensor_metas().at(i).shared_from_symbol().get();
+    };
+    JUST(user_op_expr.InferPhysicalTensorDesc(
+        infer_args.attrs(), stream->device()->type(), GetInputTensorMeta,
+        [&](int32_t i) -> TensorMeta* { return &output_mut_metas.at(i); }));
+  }
+
+  auto* mut_output_tensor_metas = result->mut_output_tensor_metas();
+  for (int32_t i = 0; i < user_op_expr.output_size(); ++i) {
+    if (!JUST(user_op_expr.SupportNonContiguous())) {
+      std::shared_ptr<Stride> stride(new Stride(output_mut_metas.at(i).shape()));
+      output_mut_metas.at(i).set_stride(stride);
+    }
+    mut_output_tensor_metas->at(i) = SymbolOf(output_mut_metas.at(i));
+  }
+  return std::shared_ptr<const LocalTensorInferResult>(std::move(result));
+}
+
+Maybe<const LocalTensorInferResult> LocalTensorInferCache::GetOrInfer(
+    const LocalTensorMetaInferArgs& infer_args) {
+  if (ThreadLocalEnvBool<ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE>()) {
+    auto iter = cache_.find(infer_args);
+    if (iter == cache_.end()) {
+      const auto& user_op_expr = user_op_expr_.lock();
+      CHECK_OR_RETURN(static_cast<bool>(user_op_expr));  // NOLINT
+      const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args));
+      iter = cache_.emplace(infer_args, output_tensor_metas).first;
+    }
+    return iter->second;
+  } else {
+    const auto& user_op_expr = user_op_expr_.lock();
+    return JUST(Infer(*user_op_expr, infer_args));
+  }
+}
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/framework/local_tensor_infer_cache.h b/oneflow/core/framework/local_tensor_infer_cache.h
new file mode 100644
index 00000000000..534278a2da5
--- /dev/null
+++ b/oneflow/core/framework/local_tensor_infer_cache.h
@@ -0,0 +1,124 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_
+#define ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_
+
+#include "oneflow/core/common/symbol.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/common/small_vector.h"
+#include "oneflow/core/common/op_args_reserved_size.h"
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/device.h"
+#include "oneflow/core/framework/stream.h"
+#include "oneflow/core/framework/tensor_meta.h"
+
+namespace oneflow {
+
+class Device;
+
+namespace one {
+
+template<typename T>
+using OpArgsVector = small_vector<T, kOpArgsReservedSize>;
+
+class TensorTuple;
+class UserOpExpr;
+
+class LocalTensorMetaInferArgs final {
+ public:
+  LocalTensorMetaInferArgs() = default;
+  LocalTensorMetaInferArgs(const LocalTensorMetaInferArgs&) = default;
+  LocalTensorMetaInferArgs(LocalTensorMetaInferArgs&&) = default;
+  ~LocalTensorMetaInferArgs() = default;
+
+  const OpArgsVector<Symbol<LocalTensorMeta>>& input_local_tensor_metas() const {
+    return input_local_tensor_metas_;
+  }
+  const AttrMap& attrs() const { return attrs_; }
+
+  const Symbol<Device>& default_device() const { return default_device_; }
+
+  size_t hash_value() const;
+
+  bool operator==(const LocalTensorMetaInferArgs& other) const;
+
+  Maybe<void> Init(const AttrMap& attrs, Symbol<Device> default_device,
+                   const TensorTuple& input_tensors);
+
+ private:
+  Maybe<void> InitInputLocalTensorMetas(const TensorTuple& input_tensors);
+
+  AttrMap attrs_;
+  Symbol<Device> default_device_;
+  OpArgsVector<Symbol<LocalTensorMeta>> input_local_tensor_metas_;
+};
+
+}  // namespace one
+}  // namespace oneflow
+
+namespace std {
+
+template<>
+struct hash<oneflow::one::LocalTensorMetaInferArgs> final {
+  size_t operator()(const oneflow::one::LocalTensorMetaInferArgs& val) const {
+    return val.hash_value();
+  }
+};
+
+}  // namespace std
+
+namespace oneflow {
+namespace one {
+
+class LocalTensorInferResult final {
+ public:
+  LocalTensorInferResult(size_t output_size) : output_tensor_metas_(output_size) {}
+  LocalTensorInferResult(const LocalTensorInferResult&) = delete;
+  LocalTensorInferResult(LocalTensorInferResult&&) = delete;
+  ~LocalTensorInferResult() = default;
+
+  const OpArgsVector<Symbol<LocalTensorMeta>>& output_tensor_metas() const {
+    return output_tensor_metas_;
+  }
+  OpArgsVector<Symbol<LocalTensorMeta>>* mut_output_tensor_metas() { return &output_tensor_metas_; }
+
+  const Symbol<Stream>& stream() const { return stream_; }
+  void set_stream(const Symbol<Stream>& stream) { stream_ = stream; }
+
+ private:
+  OpArgsVector<Symbol<LocalTensorMeta>> output_tensor_metas_;
+  Symbol<Stream> stream_;
+};
+
+class LocalTensorInferCache final {
+ public:
+  LocalTensorInferCache(const std::shared_ptr<const UserOpExpr>& user_op_expr)
+      : user_op_expr_(user_op_expr) {}
+
+  Maybe<const LocalTensorInferResult> GetOrInfer(const LocalTensorMetaInferArgs& infer_args);
+
+ private:
+  static Maybe<const LocalTensorInferResult> Infer(const UserOpExpr& user_op_expr,
+                                                   const LocalTensorMetaInferArgs& infer_args);
+
+  std::weak_ptr<const UserOpExpr> user_op_expr_;
+  HashMap<LocalTensorMetaInferArgs, std::shared_ptr<const LocalTensorInferResult>> cache_;
+};
+
+}  // namespace one
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_
diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp
index 47c5a1d0d79..13113237061 100644
--- a/oneflow/core/framework/op_expr.cpp
+++ b/oneflow/core/framework/op_expr.cpp
@@ -22,6 +22,7 @@ limitations under the License.
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/framework/op_interpreter/dispatch_frame.h"
 #include "oneflow/core/framework/user_op_registry_manager.h"
+#include "oneflow/core/framework/local_tensor_infer_cache.h"
 #include "oneflow/core/framework/global_tensor_infer_cache.h"
 #include "oneflow/core/operator/op_conf.pb.h"
 #include "oneflow/user/kernels/stateful_opkernel.h"
@@ -457,6 +458,7 @@ Maybe<void> UserOpExpr::Init(const std::shared_ptr<const UserOpExpr>& self) {
   if (registry->device_and_stream_infer_fn) {
     device_and_stream_infer_fn_ = registry->device_and_stream_infer_fn;
   }
+  local_tensor_infer_cache_.reset(new LocalTensorInferCache(self));
   global_tensor_infer_cache_.reset(new GlobalTensorInferCache(self));
   return Maybe<void>::Ok();
 }
diff --git a/oneflow/core/framework/op_expr.h b/oneflow/core/framework/op_expr.h
index d2072249388..13a7a7a0a07 100644
--- a/oneflow/core/framework/op_expr.h
+++ b/oneflow/core/framework/op_expr.h
@@ -126,6 +126,7 @@ class BuiltinOpExprImpl : public BuiltinOpExpr {
 };
 
 class StatefulOpKernel;
+class LocalTensorInferCache;
 class GlobalTensorInferCache;
 
 class UserOpExpr final : public BuiltinOpExprImpl<UserOpConf> {
@@ -159,6 +160,9 @@ class UserOpExpr final : public BuiltinOpExprImpl<UserOpConf> {
       const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex) const;
   Maybe<Symbol<Stream>> InferDeviceAndStream(const AttrMap& attrs, const TensorTuple& inputs,
                                              TensorTuple* outputs) const;
+  LocalTensorInferCache* mut_local_tensor_infer_cache() const {
+    return local_tensor_infer_cache_.get();
+  }
   GlobalTensorInferCache* mut_global_tensor_infer_cache() const {
     return global_tensor_infer_cache_.get();
   }
@@ -173,6 +177,7 @@ class UserOpExpr final : public BuiltinOpExprImpl<UserOpConf> {
   user_op::DataTypeInferFn dtype_infer_fn_;
   user_op::DeviceAndStreamInferFn device_and_stream_infer_fn_;
   mutable HashMap<Symbol<Stream>, std::shared_ptr<StatefulOpKernel>> stream2kernel_;
+  std::shared_ptr<LocalTensorInferCache> local_tensor_infer_cache_;
   std::shared_ptr<GlobalTensorInferCache> global_tensor_infer_cache_;
 };
 
diff --git a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp
index 71941c92eca..635e9889ea9 100644
--- a/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp
@@ -26,6 +26,7 @@ limitations under the License.
 #include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/tensor_name_scope.h"
 #include "oneflow/core/framework/tensor_tuple.h"
+#include "oneflow/core/framework/local_tensor_infer_cache.h"
 #include "oneflow/core/common/stride.h"
 #include "oneflow/core/memory/memory_case_util.h"
 #include "oneflow/core/operator/operator.h"
@@ -47,9 +48,19 @@ namespace one {
 
 namespace {
 
-Maybe<Symbol<Device>> GetDefaultDevice(const OpExprInterpContext& ctx) {
-  if (ctx.device.has_value()) { return JUST(ctx.device); }
-  return Device::New("cpu", 0);
+Maybe<Symbol<Device>> RawGetDefaultCpuDevice() { return Device::New("cpu", 0); }
+
+constexpr auto* GetDefaultCpuDevice = DECORATE(&RawGetDefaultCpuDevice, ThreadLocal);
+
+Maybe<Symbol<Device>> GetDefaultDevice(const TensorTuple& inputs, const OpExprInterpContext& ctx) {
+  if (inputs.empty()) {
+    if (ctx.device.has_value()) {
+      return JUST(ctx.device);
+    } else {
+      return GetDefaultCpuDevice();
+    }
+  }
+  return JUST(inputs.at(0)->device());
 }
 
 Maybe<EagerLocalTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>& tensor) {
@@ -57,105 +68,37 @@ Maybe<EagerLocalTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>& te
   return tensor->mut_eager_local_tensor_impl();
 }
 
-class MutLocalTensorMeta : public TensorMeta {  // NOLINT
- public:
-  MutLocalTensorMeta()
-      : TensorMeta(std::make_shared<const Shape>(), std::make_shared<const Stride>(),
-                   kInvalidDataType) {}
-  MutLocalTensorMeta(const MutLocalTensorMeta&) = default;
-  MutLocalTensorMeta(MutLocalTensorMeta&&) = default;
-  ~MutLocalTensorMeta() override = default;
-};
-
-std::vector<TensorMeta*>* ThreadLocalDefaultOutputMutTensorMetas(int64_t size) {
-  static thread_local std::vector<MutLocalTensorMeta> struct_vec;
-  static thread_local std::vector<TensorMeta*> ptr_vec;
-  struct_vec.resize(size);
-  ptr_vec.resize(size);
-  if (size == 1) {
-    ptr_vec.at(0) = &struct_vec.at(0);  // unfold loop
-  } else if (size == 2) {
-    ptr_vec.at(0) = &struct_vec.at(0);  // unfold loop
-    ptr_vec.at(1) = &struct_vec.at(1);  // unfold loop
-  } else {
-    for (int i = 0; i < size; ++i) { ptr_vec.at(i) = &struct_vec.at(i); }
-  }
-  return &ptr_vec;
-}
-
 }  // namespace
 
 Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
-                           const Symbol<Device>& default_device, TensorTuple* outputs,
-                           const OpExprInterpContext& ctx) {
+                           TensorTuple* outputs, const OpExprInterpContext& ctx) {
   OF_PROFILER_RANGE_GUARD("NaiveInterpret");
-  OF_PROFILER_RANGE_PUSH("init inputs");
-  const auto& attrs = ctx.attrs;
+  CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());  // NOLINT
+  Symbol<Device> default_device = JUST(GetDefaultDevice(inputs, ctx));
+  const std::shared_ptr<const LocalTensorInferResult> result =
+      JUST([&]() -> Maybe<const LocalTensorInferResult> {
+        LocalTensorMetaInferArgs infer_args;
+        JUST(infer_args.Init(ctx.attrs, default_device, inputs));
+        return JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args));
+      }());
+
   vm::EagerBlobObjectList input_eager_blob_objects(inputs.size());
   for (int i = 0; i < inputs.size(); i++) {
-    const auto& input_device = JUST(inputs.at(i)->device());
-    if (i > 0) {
-      CHECK_OR_RETURN(default_device == input_device)
-          << Error::RuntimeError()
-          << "Expected all tensors to be on the same device, but found at least two devices, "
-          << default_device->ToString() << " (positional 0) and " << input_device->ToString()
-          << " (positional " << i << ")!";
-    }
     input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object());
   }
-  OF_PROFILER_RANGE_POP();
-  OF_PROFILER_RANGE_PUSH("init outputs");
+
+  const auto& output_tensor_metas = result->output_tensor_metas();
   vm::EagerBlobObjectList output_eager_blob_objects(outputs->size());
-  auto* output_tensor_metas =
-      ThreadLocalDefaultOutputMutTensorMetas(outputs->size());
+
   for (int i = 0; i < outputs->size(); i++) {
     if (!outputs->at(i)) {
-      const auto& tensor_impl = std::make_shared<EagerLocalTensorImpl>();
-      (*outputs)[i] = std::make_shared<LocalTensor>(tensor_impl);
-      output_tensor_metas->at(i) = tensor_impl->mut_tensor_meta();
-    } else {
-      bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object());
-      CHECK_OR_RETURN(has_eager_blob_object);
-      output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object());
-    }
-  }
-  Symbol<Stream> stream;
-
-  OF_PROFILER_RANGE_POP();
-  OF_PROFILER_RANGE_PUSH("infer devices");
-  // Infer devices
-  if (!user_op_expr.has_device_and_stream_infer_fn()) {
-    stream = JUST(GetDefaultStreamByDevice(default_device));
-    for (int i = 0; i < outputs->size(); i++) {
-      auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
-      *JUST(tensor_impl->mut_device()) = default_device;
-    }
-  } else {
-    stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs));
-  }
-
-  OF_PROFILER_RANGE_POP();
-  OF_PROFILER_RANGE_PUSH("infer shapes and dtypes");
-  // Infer shapes and dtypes
-  const auto& device_tag = stream->device()->type();
-  JUST(user_op_expr.InferPhysicalTensorDesc(
-      attrs, device_tag,
-      [&](int32_t i) -> const TensorMeta* {
-        return CHECK_JUST(TensorImpl4Tensor(inputs[i]))->mut_tensor_meta();
-      },
-      [&](int32_t i) -> TensorMeta* {
-        // using thread_local TensorMeta pointer if inplace.
-        // using tensor_impl TensorMeta pointer if not inplace.
-        return output_tensor_metas->at(i);
-      }));
-
-  OF_PROFILER_RANGE_POP();
-  OF_PROFILER_RANGE_PUSH("init output eager_blob_objects");
-  for (int i = 0; i < output_eager_blob_objects.size(); i++) {
-    auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
-    if (!output_eager_blob_objects.at(i)) {
       // NOTE: if op support stride(non-contiguous input), then output tensor's stride
       // should be inferred in InferLogicalTensorDesc.
       // otherwise, it will be set here(according to shape).
+      // Note: symbol.shared_from_symbol() cannot be used here because set_stride happens in the
+      // next step.
+      std::shared_ptr<EagerLocalTensorImpl> tensor_impl = std::make_shared<EagerLocalTensorImpl>(
+          std::make_shared<LocalTensorMeta>(*output_tensor_metas.at(i)), false, false);
       if (!JUST(user_op_expr.SupportNonContiguous())) {
        std::shared_ptr<Stride> stride(new Stride(*tensor_impl->shape()));
        tensor_impl->mut_tensor_meta()->set_stride(stride);
@@ -163,25 +106,29 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
       const auto& dep_object = NewLocalDepObject();
       JUST(tensor_impl->InitEagerBlobObject(dep_object));
       output_eager_blob_objects.at(i) = JUST(tensor_impl->eager_blob_object());
+      (*outputs)[i] = std::make_shared<LocalTensor>(tensor_impl);
     } else {
+      auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
       // output i is inplaced.
-      // check thread_local TensorMeta and tensor_impl TensorMeta.
-      CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape());
-      // TODO:(thread_local TensorMeta set stride then check)
+      // check TensorMeta of infer result and TensorMeta of output i.
+      CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape()       // NOLINT
+                      == output_tensor_metas.at(i)->shape());   // NOLINT
+      CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype()       // NOLINT
+                      == output_tensor_metas.at(i)->dtype());   // NOLINT
+      bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object());
+      CHECK_OR_RETURN(has_eager_blob_object);  // NOLINT
+      output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object());
+      // TODO(zhaoluyang):(thread_local TensorMeta set stride then check)
       // CHECK_OR_RETURN(tensor_impl->tensor_meta()->stride() ==
       // output_tensor_metas->at(i)->stride());
-      CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype());
     }
   }
-  OF_PROFILER_RANGE_POP();
-  OF_PROFILER_RANGE_PUSH("init opkernel");
-  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream));
-  OF_PROFILER_RANGE_POP();
-  OF_PROFILER_RANGE_PUSH("PhysicalRun");
+  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream()));
+
   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
     return builder->Call(kernel, std::move(input_eager_blob_objects),
-                         std::move(output_eager_blob_objects), ctx, stream);
+                         std::move(output_eager_blob_objects), ctx, result->stream());
   }));
   for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) {
     const auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(index)));
@@ -192,20 +139,8 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
     }));
     JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));
   }
-  OF_PROFILER_RANGE_POP();
-  return Maybe<void>::Ok();
-}
-
-static Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
-                                  TensorTuple* outputs, const OpExprInterpContext& ctx) {
-  CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());
-  Symbol<Device> default_device;
-  if (inputs.empty()) {
-    default_device = JUST(GetDefaultDevice(ctx));
-  } else {
-    default_device = JUST(inputs.at(0)->device());
-  }
-  return NaiveInterpret(user_op_expr, inputs, default_device, outputs, ctx);
+  return Maybe<void>::Ok();
 }
 
 Maybe<void> EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,
diff --git a/oneflow/core/framework/tensor_meta.h b/oneflow/core/framework/tensor_meta.h
index a8de6998828..1316706bba9 100644
--- a/oneflow/core/framework/tensor_meta.h
+++ b/oneflow/core/framework/tensor_meta.h
@@ -42,7 +42,11 @@ class TensorMeta : public user_op::TensorDesc {
   TensorMeta(const std::shared_ptr<const Shape>& shape,
              const std::shared_ptr<const Stride>& stride, DataType dtype)
       : shape_(shape), stride_(stride), data_type_(dtype), is_dynamic_(false) {}
-  TensorMeta(const TensorMeta&) = default;
+  TensorMeta(const TensorMeta& other)
+      : shape_(std::make_shared<Shape>(*other.shape_)),
+        stride_(std::make_shared<Stride>(*other.stride_)),
+        data_type_(other.data_type_),
+        is_dynamic_(other.is_dynamic_) {}
   TensorMeta(TensorMeta&&) = default;
   virtual ~TensorMeta() = default;
 
@@ -66,6 +70,15 @@ class TensorMeta : public user_op::TensorDesc {
   bool* mut_is_dynamic() override { return &is_dynamic_; }
   void set_is_dynamic(bool val) override { is_dynamic_ = val; }
 
+ protected:
+  TensorMeta& operator=(const TensorMeta& other) {
+    this->shape_ = std::make_shared<Shape>(*other.shape_);
+    this->stride_ = std::make_shared<Stride>(*other.stride_);
+    this->data_type_ = other.data_type_;
+    this->is_dynamic_ = other.is_dynamic_;
+    return *this;
+  }
+
  private:
   std::shared_ptr<const Shape> shape_;
   std::shared_ptr<const Stride> stride_;
@@ -77,6 +90,7 @@ class LocalTensorMeta : public TensorMeta {
  public:
   // uninitialized LocalTensorMeta.
   LocalTensorMeta();
+  LocalTensorMeta(const LocalTensorMeta&) = default;
   LocalTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype, Symbol<Device> device);
   LocalTensorMeta(const std::shared_ptr<const Shape>& shape,
                   const std::shared_ptr<const Stride>& stride, DataType dtype,
@@ -92,6 +106,8 @@ class LocalTensorMeta : public TensorMeta {
   bool operator==(const LocalTensorMeta& other) const;
   size_t CalcHashValue() const;
 
+  LocalTensorMeta& operator=(const LocalTensorMeta& other) = default;
+
  private:
   Symbol<Device> device_;
   int64_t storage_offset_;
@@ -127,6 +143,13 @@ class GlobalTensorMeta : public TensorMeta {
 
 namespace std {
 
+template<>
+struct hash<oneflow::one::LocalTensorMeta> final {
+  size_t operator()(const oneflow::one::LocalTensorMeta& local_tensor_meta) const {
+    return local_tensor_meta.CalcHashValue();
+  }
+};
+
 template<>
 struct hash<oneflow::one::GlobalTensorMeta> final {
   size_t operator()(const oneflow::one::GlobalTensorMeta& global_tensor_meta) const {
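
The change above wires per-UserOpExpr memoization of local tensor-meta inference into the eager interpreter: NaiveInterpret builds a LocalTensorMetaInferArgs key from the attrs, the default device, and the input tensor metas, asks LocalTensorInferCache::GetOrInfer for a shared, immutable LocalTensorInferResult (output metas plus stream), and only falls back to uncached inference when ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE is off. The stand-alone sketch below (plain C++17, hypothetical names, not OneFlow's API) illustrates that control flow under simplified assumptions: a hashable key type, a hash-map cache, and an environment-variable gate.

// Minimal illustrative sketch only -- none of these names exist in OneFlow.
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace sketch {

// Stand-in for the per-tensor metadata that the real cache keys on.
struct TensorMeta {
  std::vector<int64_t> shape;
  int dtype = 0;
  bool operator==(const TensorMeta& o) const { return shape == o.shape && dtype == o.dtype; }
};

// Stand-in for LocalTensorMetaInferArgs: the hashable cache key.
struct InferArgs {
  std::vector<TensorMeta> input_metas;
  bool operator==(const InferArgs& o) const { return input_metas == o.input_metas; }
  size_t hash_value() const {
    size_t h = 0;
    for (const TensorMeta& m : input_metas) {
      for (int64_t d : m.shape) { h = h * 31 + std::hash<int64_t>()(d); }
      h = h * 31 + std::hash<int>()(m.dtype);
    }
    return h;
  }
};

struct InferArgsHash {
  size_t operator()(const InferArgs& args) const { return args.hash_value(); }
};

// Stand-in for LocalTensorInferResult: immutable once cached, shared by later lookups.
struct InferResult {
  std::vector<TensorMeta> output_metas;
};

// Mirrors the idea of the ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE gate (hypothetical variable name).
inline bool CacheEnabled() {
  static const bool enabled = [] {
    const char* v = std::getenv("ENABLE_LOCAL_INFER_CACHE");
    return v == nullptr || std::string(v) != "0";
  }();
  return enabled;
}

class InferCache {
 public:
  using InferFn = std::function<InferResult(const InferArgs&)>;
  explicit InferCache(InferFn infer) : infer_(std::move(infer)) {}

  // Cached path: look up, run inference only on a miss, then store the shared result.
  // Uncached path: run inference every call, as GetOrInfer does when the env var is off.
  std::shared_ptr<const InferResult> GetOrInfer(const InferArgs& args) {
    if (!CacheEnabled()) { return std::make_shared<const InferResult>(infer_(args)); }
    auto iter = cache_.find(args);
    if (iter == cache_.end()) {
      iter = cache_.emplace(args, std::make_shared<const InferResult>(infer_(args))).first;
    }
    return iter->second;
  }

 private:
  InferFn infer_;
  std::unordered_map<InferArgs, std::shared_ptr<const InferResult>, InferArgsHash> cache_;
};

}  // namespace sketch

int main() {
  // Toy "op": the single output meta equals the first input meta.
  sketch::InferCache cache([](const sketch::InferArgs& args) {
    std::cout << "running inference\n";  // printed only on a cache miss (or when caching is off)
    return sketch::InferResult{{args.input_metas.at(0)}};
  });
  sketch::TensorMeta meta;
  meta.shape = {2, 3};
  meta.dtype = 1;
  sketch::InferArgs args;
  args.input_metas.push_back(meta);
  auto first = cache.GetOrInfer(args);
  auto second = cache.GetOrInfer(args);  // same key: reuses the shared result when caching is on
  std::cout << (first == second ? "cache hit\n" : "no cache\n");
  return 0;
}

Because results are stored behind a shared pointer to an immutable object, every later call with the same key reuses the same metas rather than mutating them, which is why the diff above checks the shape and dtype of in-place outputs against the cached metas instead of writing into them.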